mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-10 21:07:27 +00:00
Compare commits
149 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6192aebe66 | ||
|
|
a85a594597 | ||
|
|
014c9450ba | ||
|
|
63640f65e8 | ||
|
|
fd040988a3 | ||
|
|
f7c3b043b5 | ||
|
|
93e7675bc3 | ||
|
|
d7c97d4d34 | ||
|
|
dce794dbf7 | ||
|
|
093d86040f | ||
|
|
39617bc8c6 | ||
|
|
7da224ba92 | ||
|
|
df862732df | ||
|
|
fd4447f60a | ||
|
|
ea79d59aa0 | ||
|
|
41b0cf406c | ||
|
|
ef32cc8e0a | ||
|
|
ee8956b0e9 | ||
|
|
5ad9f8d931 | ||
|
|
ea379e1d0e | ||
|
|
b842baf21f | ||
|
|
58c9c7d5dd | ||
|
|
384fadf227 | ||
|
|
e4def0625b | ||
|
|
44d20de251 | ||
|
|
7ea33c2ddf | ||
|
|
b43423bffc | ||
|
|
cf4700a35c | ||
|
|
6bb552128c | ||
|
|
50b4fc06f8 | ||
|
|
f7f1be9df2 | ||
|
|
59574dc80f | ||
|
|
7577ec1ac4 | ||
|
|
d487be0029 | ||
|
|
83a3872b97 | ||
|
|
1ad2f63f85 | ||
|
|
fcaa8317e4 | ||
|
|
ccda14255a | ||
|
|
8d66828229 | ||
|
|
4ebf9e35e1 | ||
|
|
2902d6c7c2 | ||
|
|
01ef1fe4e4 | ||
|
|
c3d2d07b68 | ||
|
|
18417bacb3 | ||
|
|
8ec18dd21b | ||
|
|
edaff1c689 | ||
|
|
9c3a13cb23 | ||
|
|
0b326e7af4 | ||
|
|
1a1ff836b5 | ||
|
|
34fed74f64 | ||
|
|
f89b29928c | ||
|
|
2c6d4460c3 | ||
|
|
7afd3f97ee | ||
|
|
0708452939 | ||
|
|
a9e5d99ea3 | ||
|
|
a56d9ea98b | ||
|
|
f5e80af0b3 | ||
|
|
a1a7ddbc83 | ||
|
|
8b209d8926 | ||
|
|
9344cab59a | ||
|
|
03468e05e4 | ||
|
|
11792ba1a4 | ||
|
|
5baaa06896 | ||
|
|
d3286893c4 | ||
|
|
b087b20bac | ||
|
|
6a5a839d4d | ||
|
|
5d8a0952b4 | ||
|
|
bd08ecc1e0 | ||
|
|
e4f61c1084 | ||
|
|
a38215478f | ||
|
|
c192d07a04 | ||
|
|
098880b796 | ||
|
|
150c506ece | ||
|
|
f978d8224e | ||
|
|
ab59887933 | ||
|
|
458472f3e2 | ||
|
|
a9f98c5d39 | ||
|
|
2b7dff2d94 | ||
|
|
58752d2dcf | ||
|
|
67546f4b2a | ||
|
|
8e9dae7b5f | ||
|
|
fb4ff63bad | ||
|
|
1fed1ee567 | ||
|
|
02571c20ff | ||
|
|
8201daa4b4 | ||
|
|
5b54624cd5 | ||
|
|
db737567fb | ||
|
|
5c3898d13e | ||
|
|
2fdb2be6d0 | ||
|
|
ab78efc815 | ||
|
|
fcf97d1796 | ||
|
|
e85cc6acbe | ||
|
|
da002e6ca9 | ||
|
|
070e7b6911 | ||
|
|
d5a3eb7d04 | ||
|
|
616e6953cc | ||
|
|
b7c77777a5 | ||
|
|
8a79de333a | ||
|
|
a87d4271d3 | ||
|
|
7975cdf3bf | ||
|
|
b2badad554 | ||
|
|
133d8c9f77 | ||
|
|
9708d645d3 | ||
|
|
0bca4d3efc | ||
|
|
7572e791f6 | ||
|
|
16c63b3be9 | ||
|
|
37fbcb7950 | ||
|
|
a180d13182 | ||
|
|
a6363a502a | ||
|
|
81bc096872 | ||
|
|
edcdb378fd | ||
|
|
4447e51588 | ||
|
|
fb8aac650f | ||
|
|
ba6b0637cc | ||
|
|
3502730dfc | ||
|
|
b95c5bb8f4 | ||
|
|
f35784aa97 | ||
|
|
3746482e8c | ||
|
|
5ed4b60b8f | ||
|
|
547da2da60 | ||
|
|
f88ed4dd5c | ||
|
|
87fc681df3 | ||
|
|
a39b2f5aa7 | ||
|
|
0b9b21eafd | ||
|
|
21f43b0dd8 | ||
|
|
3a7ba5725c | ||
|
|
2e4fa32d63 | ||
|
|
0199896d9a | ||
|
|
edd9049100 | ||
|
|
290c763901 | ||
|
|
226446a3b5 | ||
|
|
ab627db4be | ||
|
|
0f35d2368f | ||
|
|
3c276d13c4 | ||
|
|
b7c3328d43 | ||
|
|
4d8e63bd1a | ||
|
|
7fa21ce95f | ||
|
|
296da5dbcc | ||
|
|
1ec2bbd533 | ||
|
|
d67d5d8006 | ||
|
|
c4f25a77d1 | ||
|
|
52763c09f2 | ||
|
|
e77555a04f | ||
|
|
4a313a5f93 | ||
|
|
3d9587f128 | ||
|
|
66778efcc5 | ||
|
|
6be78ff283 | ||
|
|
5281f2ba64 | ||
|
|
bc322ddac4 |
@@ -24,8 +24,7 @@ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-
|
||||
|
||||
FROM alpine
|
||||
|
||||
RUN apk update \
|
||||
&& apk upgrade \
|
||||
RUN apk upgrade --no-cache \
|
||||
&& apk add --no-cache ca-certificates tzdata ffmpeg \
|
||||
&& update-ca-certificates
|
||||
|
||||
|
||||
@@ -241,6 +241,7 @@ const (
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeKling = 50
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
@@ -296,4 +297,5 @@ var ChannelBaseURLs = []string{
|
||||
"", //47
|
||||
"https://api.x.ai", //48
|
||||
"https://api.coze.cn", //49
|
||||
"https://api.klingai.com", //50
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
@@ -284,3 +285,20 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64
|
||||
}
|
||||
return strconv.ParseFloat(durationStr, 64)
|
||||
}
|
||||
|
||||
// BuildURL concatenates base and endpoint, returns the complete url string
|
||||
func BuildURL(base string, endpoint string) string {
|
||||
u, err := url.Parse(base)
|
||||
if err != nil {
|
||||
return base + endpoint
|
||||
}
|
||||
end := endpoint
|
||||
if end == "" {
|
||||
end = "/"
|
||||
}
|
||||
ref, err := url.Parse(end)
|
||||
if err != nil {
|
||||
return base + endpoint
|
||||
}
|
||||
return u.ResolveReference(ref).String()
|
||||
}
|
||||
|
||||
@@ -7,4 +7,5 @@ const (
|
||||
ContextKeyUserStatus = "user_status"
|
||||
ContextKeyUserEmail = "user_email"
|
||||
ContextKeyUserGroup = "user_group"
|
||||
ContextKeyUsingGroup = "group"
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ type TaskPlatform string
|
||||
const (
|
||||
TaskPlatformSuno TaskPlatform = "suno"
|
||||
TaskPlatformMidjourney = "mj"
|
||||
TaskPlatformKling TaskPlatform = "kling"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -4,11 +4,13 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/shopspring/decimal"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -304,6 +306,40 @@ func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.moonshot.cn/v1/users/me/balance"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
type MoonshotBalanceData struct {
|
||||
AvailableBalance float64 `json:"available_balance"`
|
||||
VoucherBalance float64 `json:"voucher_balance"`
|
||||
CashBalance float64 `json:"cash_balance"`
|
||||
}
|
||||
|
||||
type MoonshotBalanceResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data MoonshotBalanceData `json:"data"`
|
||||
Scode string `json:"scode"`
|
||||
Status bool `json:"status"`
|
||||
}
|
||||
|
||||
response := MoonshotBalanceResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !response.Status || response.Code != 0 {
|
||||
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
||||
}
|
||||
availableBalanceCny := response.Data.AvailableBalance
|
||||
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
|
||||
channel.UpdateBalance(availableBalanceUsd)
|
||||
return availableBalanceUsd, nil
|
||||
}
|
||||
|
||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() == "" {
|
||||
@@ -332,6 +368,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
return updateChannelDeepSeekBalance(channel)
|
||||
case common.ChannelTypeOpenRouter:
|
||||
return updateChannelOpenRouterBalance(channel)
|
||||
case common.ChannelTypeMoonshot:
|
||||
return updateChannelMoonshotBalance(channel)
|
||||
default:
|
||||
return 0, errors.New("尚未实现")
|
||||
}
|
||||
|
||||
@@ -40,6 +40,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
if channel.Type == common.ChannelTypeSunoAPI {
|
||||
return errors.New("suno channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeKling {
|
||||
return errors.New("kling channel test is not supported"), nil
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
@@ -90,7 +93,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
|
||||
info := relaycommon.GenRelayInfo(c)
|
||||
|
||||
err = helper.ModelMappedHelper(c, info)
|
||||
err = helper.ModelMappedHelper(c, info, nil)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
@@ -165,10 +168,10 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.UserGroupRatio)
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
|
||||
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
|
||||
quota, "模型测试", 0, quota, int(consumedTime), false, info.UsingGroup, other)
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
}
|
||||
@@ -312,7 +315,7 @@ func testAllChannels(notify bool) error {
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
time.Sleep(common.RequestInterval)
|
||||
}
|
||||
|
||||
|
||||
if notify {
|
||||
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
|
||||
}
|
||||
|
||||
@@ -52,6 +52,14 @@ func GetAllChannels(c *gin.Context) {
|
||||
channelData := make([]*model.Channel, 0)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
// type filter
|
||||
typeStr := c.Query("type")
|
||||
typeFilter := -1
|
||||
if typeStr != "" {
|
||||
if t, err := strconv.Atoi(typeStr); err == nil {
|
||||
typeFilter = t
|
||||
}
|
||||
}
|
||||
|
||||
var total int64
|
||||
|
||||
@@ -72,6 +80,14 @@ func GetAllChannels(c *gin.Context) {
|
||||
}
|
||||
// 计算 tag 总数用于分页
|
||||
total, _ = model.CountAllTags()
|
||||
} else if typeFilter >= 0 {
|
||||
channels, err := model.GetChannelsByType((p-1)*pageSize, pageSize, idSort, typeFilter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
channelData = channels
|
||||
total, _ = model.CountChannelsByType(typeFilter)
|
||||
} else {
|
||||
channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort)
|
||||
if err != nil {
|
||||
@@ -82,14 +98,18 @@ func GetAllChannels(c *gin.Context) {
|
||||
total, _ = model.CountAllChannels()
|
||||
}
|
||||
|
||||
// calculate type counts
|
||||
typeCounts, _ := model.CountChannelsGroupByType()
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": channelData,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
"items": channelData,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
"type_counts": typeCounts,
|
||||
},
|
||||
})
|
||||
return
|
||||
@@ -114,13 +134,6 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
//if channel.Type != common.ChannelTypeOpenAI {
|
||||
// c.JSON(http.StatusOK, gin.H{
|
||||
// "success": false,
|
||||
// "message": "仅支持 OpenAI 类型渠道",
|
||||
// })
|
||||
// return
|
||||
//}
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
@@ -217,10 +230,20 @@ func SearchChannels(c *gin.Context) {
|
||||
}
|
||||
channelData = channels
|
||||
}
|
||||
|
||||
// calculate type counts for search results
|
||||
typeCounts := make(map[int64]int64)
|
||||
for _, channel := range channelData {
|
||||
typeCounts[int64(channel.Type)]++
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channelData,
|
||||
"data": gin.H{
|
||||
"items": channelData,
|
||||
"type_counts": typeCounts,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -516,6 +539,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
channel.Key = ""
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetGroups(c *gin.Context) {
|
||||
groupNames := make([]string, 0)
|
||||
for groupName, _ := range setting.GetGroupRatioCopy() {
|
||||
for groupName := range ratio_setting.GetGroupRatioCopy() {
|
||||
groupNames = append(groupNames, groupName)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -24,7 +26,7 @@ func GetUserGroups(c *gin.Context) {
|
||||
userGroup := ""
|
||||
userId := c.GetInt("id")
|
||||
userGroup, _ = model.GetUserGroup(userId, false)
|
||||
for groupName, ratio := range setting.GetGroupRatioCopy() {
|
||||
for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
|
||||
// UserUsableGroups contains the groups that the user can use
|
||||
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
||||
if desc, ok := userUsableGroups[groupName]; ok {
|
||||
@@ -34,6 +36,12 @@ func GetUserGroups(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if setting.GroupInUserUsableGroups("auto") {
|
||||
usableGroups["auto"] = map[string]interface{}{
|
||||
"ratio": "自动",
|
||||
"desc": setting.GetUsableGroupDescription("auto"),
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"one-api/setting/console_setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -41,46 +41,48 @@ func GetStatus(c *gin.Context) {
|
||||
cs := console_setting.GetConsoleSetting()
|
||||
|
||||
data := gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
|
||||
// 面板启用开关
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
|
||||
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||
"faq_enabled": cs.FAQEnabled,
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
|
||||
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||
"faq_enabled": cs.FAQEnabled,
|
||||
|
||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||
|
||||
@@ -2,7 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -15,6 +15,9 @@ import (
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
@@ -134,6 +137,9 @@ func init() {
|
||||
adaptor.Init(meta)
|
||||
channelId2Models[i] = adaptor.GetModelList()
|
||||
}
|
||||
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
|
||||
return m.Id
|
||||
})
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
@@ -179,7 +185,19 @@ func ListModels(c *gin.Context) {
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
models := model.GetGroupModels(group)
|
||||
var models []string
|
||||
if tokenGroup == "auto" {
|
||||
for _, autoGroup := range setting.AutoGroups {
|
||||
groupModels := model.GetGroupModels(autoGroup)
|
||||
for _, g := range groupModels {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
models = model.GetGroupModels(group)
|
||||
}
|
||||
for _, s := range models {
|
||||
if _, ok := openAIModelsMap[s]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
|
||||
@@ -103,7 +104,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "GroupRatio":
|
||||
err = setting.CheckGroupRatio(option.Value)
|
||||
err = ratio_setting.CheckGroupRatio(option.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -3,7 +3,6 @@ package controller
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -13,6 +12,8 @@ import (
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Playground(c *gin.Context) {
|
||||
@@ -57,13 +58,22 @@ func Playground(c *gin.Context) {
|
||||
c.Set("group", group)
|
||||
}
|
||||
c.Set("token_name", "playground-"+group)
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
|
||||
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
|
||||
// Write user context to ensure acceptUnsetRatio is available
|
||||
userId := c.GetInt("id")
|
||||
userCache, err := model.GetUserCache(userId)
|
||||
if err != nil {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
userCache.WriteContext(c)
|
||||
Relay(c)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package controller
|
||||
import (
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -13,7 +13,7 @@ func GetPricing(c *gin.Context) {
|
||||
userId, exists := c.Get("id")
|
||||
usableGroup := map[string]string{}
|
||||
groupRatio := map[string]float64{}
|
||||
for s, f := range setting.GetGroupRatioCopy() {
|
||||
for s, f := range ratio_setting.GetGroupRatioCopy() {
|
||||
groupRatio[s] = f
|
||||
}
|
||||
var group string
|
||||
@@ -22,7 +22,7 @@ func GetPricing(c *gin.Context) {
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
for g := range groupRatio {
|
||||
ratio, ok := setting.GetGroupGroupRatio(group, g)
|
||||
ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
|
||||
if ok {
|
||||
groupRatio[g] = ratio
|
||||
}
|
||||
@@ -32,7 +32,7 @@ func GetPricing(c *gin.Context) {
|
||||
|
||||
usableGroup = setting.GetUserUsableGroups(group)
|
||||
// check groupRatio contains usableGroup
|
||||
for group := range setting.GetGroupRatioCopy() {
|
||||
for group := range ratio_setting.GetGroupRatioCopy() {
|
||||
if _, ok := usableGroup[group]; !ok {
|
||||
delete(groupRatio, group)
|
||||
}
|
||||
@@ -47,7 +47,7 @@ func GetPricing(c *gin.Context) {
|
||||
}
|
||||
|
||||
func ResetModelRatio(c *gin.Context) {
|
||||
defaultStr := operation_setting.DefaultModelRatio2JSONString()
|
||||
defaultStr := ratio_setting.DefaultModelRatio2JSONString()
|
||||
err := model.UpdateOption("ModelRatio", defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
@@ -56,7 +56,7 @@ func ResetModelRatio(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||
err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
|
||||
24
controller/ratio_config.go
Normal file
24
controller/ratio_config.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetRatioConfig(c *gin.Context) {
|
||||
if !ratio_setting.IsExposeRatioEnabled() {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "倍率配置接口未启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": ratio_setting.GetExposedData(),
|
||||
})
|
||||
}
|
||||
474
controller/ratio_sync.go
Normal file
474
controller/ratio_sync.go
Normal file
@@ -0,0 +1,474 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeoutSeconds = 10
|
||||
defaultEndpoint = "/api/ratio_config"
|
||||
maxConcurrentFetches = 8
|
||||
)
|
||||
|
||||
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||
|
||||
type upstreamResult struct {
|
||||
Name string `json:"name"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
Err string `json:"err,omitempty"`
|
||||
}
|
||||
|
||||
func FetchUpstreamRatios(c *gin.Context) {
|
||||
var req dto.UpstreamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Timeout <= 0 {
|
||||
req.Timeout = defaultTimeoutSeconds
|
||||
}
|
||||
|
||||
var upstreams []dto.UpstreamDTO
|
||||
|
||||
if len(req.Upstreams) > 0 {
|
||||
for _, u := range req.Upstreams {
|
||||
if strings.HasPrefix(u.BaseURL, "http") {
|
||||
if u.Endpoint == "" {
|
||||
u.Endpoint = defaultEndpoint
|
||||
}
|
||||
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
|
||||
upstreams = append(upstreams, u)
|
||||
}
|
||||
}
|
||||
} else if len(req.ChannelIDs) > 0 {
|
||||
intIds := make([]int, 0, len(req.ChannelIDs))
|
||||
for _, id64 := range req.ChannelIDs {
|
||||
intIds = append(intIds, int(id64))
|
||||
}
|
||||
dbChannels, err := model.GetChannelsByIds(intIds)
|
||||
if err != nil {
|
||||
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
||||
return
|
||||
}
|
||||
for _, ch := range dbChannels {
|
||||
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
||||
upstreams = append(upstreams, dto.UpstreamDTO{
|
||||
ID: ch.Id,
|
||||
Name: ch.Name,
|
||||
BaseURL: strings.TrimRight(base, "/"),
|
||||
Endpoint: "",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(upstreams) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
ch := make(chan upstreamResult, len(upstreams))
|
||||
|
||||
sem := make(chan struct{}, maxConcurrentFetches)
|
||||
|
||||
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
||||
|
||||
for _, chn := range upstreams {
|
||||
wg.Add(1)
|
||||
go func(chItem dto.UpstreamDTO) {
|
||||
defer wg.Done()
|
||||
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
|
||||
endpoint := chItem.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = defaultEndpoint
|
||||
} else if !strings.HasPrefix(endpoint, "/") {
|
||||
endpoint = "/" + endpoint
|
||||
}
|
||||
fullURL := chItem.BaseURL + endpoint
|
||||
|
||||
uniqueName := chItem.Name
|
||||
if chItem.ID != 0 {
|
||||
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||
if err != nil {
|
||||
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
||||
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
||||
return
|
||||
}
|
||||
// 兼容两种上游接口格式:
|
||||
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
||||
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
||||
var body struct {
|
||||
Success bool `json:"success"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
if !body.Success {
|
||||
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试按 type1 解析
|
||||
var type1Data map[string]any
|
||||
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
||||
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
||||
isType1 := false
|
||||
for _, rt := range ratioTypes {
|
||||
if _, ok := type1Data[rt]; ok {
|
||||
isType1 = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if isType1 {
|
||||
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
||||
var pricingItems []struct {
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
}
|
||||
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
|
||||
common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
|
||||
return
|
||||
}
|
||||
|
||||
modelRatioMap := make(map[string]float64)
|
||||
completionRatioMap := make(map[string]float64)
|
||||
modelPriceMap := make(map[string]float64)
|
||||
|
||||
for _, item := range pricingItems {
|
||||
if item.QuotaType == 1 {
|
||||
modelPriceMap[item.ModelName] = item.ModelPrice
|
||||
} else {
|
||||
modelRatioMap[item.ModelName] = item.ModelRatio
|
||||
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
||||
completionRatioMap[item.ModelName] = item.CompletionRatio
|
||||
}
|
||||
}
|
||||
|
||||
converted := make(map[string]any)
|
||||
|
||||
if len(modelRatioMap) > 0 {
|
||||
ratioAny := make(map[string]any, len(modelRatioMap))
|
||||
for k, v := range modelRatioMap {
|
||||
ratioAny[k] = v
|
||||
}
|
||||
converted["model_ratio"] = ratioAny
|
||||
}
|
||||
|
||||
if len(completionRatioMap) > 0 {
|
||||
compAny := make(map[string]any, len(completionRatioMap))
|
||||
for k, v := range completionRatioMap {
|
||||
compAny[k] = v
|
||||
}
|
||||
converted["completion_ratio"] = compAny
|
||||
}
|
||||
|
||||
if len(modelPriceMap) > 0 {
|
||||
priceAny := make(map[string]any, len(modelPriceMap))
|
||||
for k, v := range modelPriceMap {
|
||||
priceAny[k] = v
|
||||
}
|
||||
converted["model_price"] = priceAny
|
||||
}
|
||||
|
||||
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
||||
}(chn)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
|
||||
localData := ratio_setting.GetExposedData()
|
||||
|
||||
var testResults []dto.TestResult
|
||||
var successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}
|
||||
|
||||
for r := range ch {
|
||||
if r.Err != "" {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "error",
|
||||
Error: r.Err,
|
||||
})
|
||||
} else {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "success",
|
||||
})
|
||||
successfulChannels = append(successfulChannels, struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}{name: r.Name, data: r.Data})
|
||||
}
|
||||
}
|
||||
|
||||
differences := buildDifferences(localData, successfulChannels)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"differences": differences,
|
||||
"test_results": testResults,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}) map[string]map[string]dto.DifferenceItem {
|
||||
differences := make(map[string]map[string]dto.DifferenceItem)
|
||||
|
||||
allModels := make(map[string]struct{})
|
||||
|
||||
for _, ratioType := range ratioTypes {
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
for modelName := range localRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
for modelName := range upstreamRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
confidenceMap := make(map[string]map[string]bool)
|
||||
|
||||
// 预处理阶段:检查pricing接口的可信度
|
||||
for _, channel := range successfulChannels {
|
||||
confidenceMap[channel.name] = make(map[string]bool)
|
||||
|
||||
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
||||
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
||||
|
||||
if hasModelRatio && hasCompletionRatio {
|
||||
// 遍历所有模型,检查是否满足不可信条件
|
||||
for modelName := range allModels {
|
||||
// 默认为可信
|
||||
confidenceMap[channel.name][modelName] = true
|
||||
|
||||
// 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
|
||||
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
||||
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
||||
// 转换为float64进行比较
|
||||
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
||||
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
||||
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
||||
confidenceMap[channel.name][modelName] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 如果不是从pricing接口获取的数据,则全部标记为可信
|
||||
for modelName := range allModels {
|
||||
confidenceMap[channel.name][modelName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName := range allModels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
var localValue interface{} = nil
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
if val, exists := localRatio[modelName]; exists {
|
||||
localValue = val
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
upstreamValues := make(map[string]interface{})
|
||||
confidenceValues := make(map[string]bool)
|
||||
hasUpstreamValue := false
|
||||
hasDifference := false
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
var upstreamValue interface{} = nil
|
||||
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
if val, exists := upstreamRatio[modelName]; exists {
|
||||
upstreamValue = val
|
||||
hasUpstreamValue = true
|
||||
|
||||
if localValue != nil && localValue != val {
|
||||
hasDifference = true
|
||||
} else if localValue == val {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
}
|
||||
}
|
||||
if upstreamValue == nil && localValue == nil {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
|
||||
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
||||
hasDifference = true
|
||||
}
|
||||
|
||||
upstreamValues[channel.name] = upstreamValue
|
||||
|
||||
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
|
||||
}
|
||||
|
||||
shouldInclude := false
|
||||
|
||||
if localValue != nil {
|
||||
if hasDifference {
|
||||
shouldInclude = true
|
||||
}
|
||||
} else {
|
||||
if hasUpstreamValue {
|
||||
shouldInclude = true
|
||||
}
|
||||
}
|
||||
|
||||
if shouldInclude {
|
||||
if differences[modelName] == nil {
|
||||
differences[modelName] = make(map[string]dto.DifferenceItem)
|
||||
}
|
||||
differences[modelName][ratioType] = dto.DifferenceItem{
|
||||
Current: localValue,
|
||||
Upstreams: upstreamValues,
|
||||
Confidence: confidenceValues,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
channelHasDiff := make(map[string]bool)
|
||||
for _, ratioMap := range differences {
|
||||
for _, item := range ratioMap {
|
||||
for chName, val := range item.Upstreams {
|
||||
if val != nil && val != "same" {
|
||||
channelHasDiff[chName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName, ratioMap := range differences {
|
||||
for ratioType, item := range ratioMap {
|
||||
for chName := range item.Upstreams {
|
||||
if !channelHasDiff[chName] {
|
||||
delete(item.Upstreams, chName)
|
||||
delete(item.Confidence, chName)
|
||||
}
|
||||
}
|
||||
|
||||
allSame := true
|
||||
for _, v := range item.Upstreams {
|
||||
if v != "same" {
|
||||
allSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(item.Upstreams) == 0 || allSame {
|
||||
delete(ratioMap, ratioType)
|
||||
} else {
|
||||
differences[modelName][ratioType] = item
|
||||
}
|
||||
}
|
||||
|
||||
if len(ratioMap) == 0 {
|
||||
delete(differences, modelName)
|
||||
}
|
||||
}
|
||||
|
||||
return differences
|
||||
}
|
||||
|
||||
func GetSyncableChannels(c *gin.Context) {
|
||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var syncableChannels []dto.SyncableChannel
|
||||
for _, channel := range channels {
|
||||
if channel.GetBaseURL() != "" {
|
||||
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||
ID: channel.Id,
|
||||
Name: channel.Name,
|
||||
BaseURL: channel.GetBaseURL(),
|
||||
Status: channel.Status,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": syncableChannels,
|
||||
})
|
||||
}
|
||||
@@ -259,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
||||
AutoBan: &autoBanInt,
|
||||
}, nil
|
||||
}
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
|
||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
||||
}
|
||||
@@ -388,7 +388,7 @@ func RelayTask(c *gin.Context) {
|
||||
retryTimes = 0
|
||||
}
|
||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||
break
|
||||
@@ -420,7 +420,7 @@ func RelayTask(c *gin.Context) {
|
||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||||
var err *dto.TaskError
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
|
||||
err = relay.RelayTaskFetch(c, relayMode)
|
||||
default:
|
||||
err = relay.RelayTaskSubmit(c, relayMode)
|
||||
|
||||
@@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||
case constant.TaskPlatformSuno:
|
||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
case constant.TaskPlatformKling:
|
||||
_ = UpdateVideoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
default:
|
||||
common.SysLog("未知平台")
|
||||
}
|
||||
|
||||
140
controller/task_video.go
Normal file
140
controller/task_video.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/channel"
|
||||
)
|
||||
|
||||
func UpdateVideoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
if err := updateVideoTaskAll(ctx, channelId, taskIds, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if errUpdate != nil {
|
||||
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||
}
|
||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||
}
|
||||
adaptor := relay.GetTaskAdaptor(constant.TaskPlatformKling)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("video adaptor not found")
|
||||
}
|
||||
for _, taskId := range taskIds {
|
||||
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||
"task_id": taskId,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("FetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("Get Video Task status code: %d", resp.StatusCode)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ReadAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
var responseItem map[string]interface{}
|
||||
err = json.Unmarshal(responseBody, &responseItem)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Failed to parse video task response body: %v, body: %s", err, string(responseBody)))
|
||||
return fmt.Errorf("Unmarshal failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
code, _ := responseItem["code"].(float64)
|
||||
if code != 0 {
|
||||
return fmt.Errorf("video task fetch failed for task %s", taskId)
|
||||
}
|
||||
|
||||
data, ok := responseItem["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
common.LogError(ctx, fmt.Sprintf("Video task data format error: %s", string(responseBody)))
|
||||
return fmt.Errorf("video task data format error for task %s", taskId)
|
||||
}
|
||||
|
||||
task := taskM[taskId]
|
||||
if task == nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||
return fmt.Errorf("task %s not found", taskId)
|
||||
}
|
||||
|
||||
if status, ok := data["task_status"].(string); ok {
|
||||
switch status {
|
||||
case "submitted", "queued":
|
||||
task.Status = model.TaskStatusSubmitted
|
||||
case "processing":
|
||||
task.Status = model.TaskStatusInProgress
|
||||
case "succeed":
|
||||
task.Status = model.TaskStatusSuccess
|
||||
task.Progress = "100%"
|
||||
if url, err := adaptor.ParseResultUrl(responseItem); err == nil {
|
||||
task.FailReason = url
|
||||
} else {
|
||||
common.LogWarn(ctx, fmt.Sprintf("Failed to get url from body for task %s: %s", task.TaskID, err.Error()))
|
||||
}
|
||||
case "failed":
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if reason, ok := data["fail_reason"].(string); ok {
|
||||
task.FailReason = reason
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If task failed, refund quota
|
||||
if task.Status == model.TaskStatusFailure {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
|
||||
task.Data = responseBody
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysError("UpdateVideoTask task error: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -258,3 +258,32 @@ func UpdateToken(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// TokenBatch is the JSON request body of a batch token operation.
type TokenBatch struct {
	// Ids lists the IDs of the tokens the operation applies to.
	Ids []int `json:"ids"`
}
|
||||
|
||||
func DeleteTokenBatch(c *gin.Context) {
|
||||
tokenBatch := TokenBatch{}
|
||||
if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": count,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -97,14 +97,12 @@ func RequestEpay(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
payType := "wxpay"
|
||||
if req.PaymentMethod == "zfb" {
|
||||
payType = "alipay"
|
||||
}
|
||||
if req.PaymentMethod == "wx" {
|
||||
req.PaymentMethod = "wxpay"
|
||||
payType = "wxpay"
|
||||
|
||||
if !setting.ContainsPayMethod(req.PaymentMethod) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
@@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||
Type: payType,
|
||||
Type: req.PaymentMethod,
|
||||
ServiceTradeNo: tradeNo,
|
||||
Name: fmt.Sprintf("TUC%d", req.Amount),
|
||||
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||
|
||||
@@ -226,6 +226,9 @@ func Register(c *gin.Context) {
|
||||
UnlimitedQuota: true,
|
||||
ModelLimitsEnabled: false,
|
||||
}
|
||||
if setting.DefaultUseAutoGroup {
|
||||
token.Group = "auto"
|
||||
}
|
||||
if err := token.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -15,6 +15,7 @@ type ImageRequest struct {
|
||||
Background string `json:"background,omitempty"`
|
||||
Moderation string `json:"moderation,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
|
||||
@@ -53,9 +53,11 @@ type GeneralOpenAIRequest struct {
|
||||
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||
Audio json.RawMessage `json:"audio,omitempty"`
|
||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
||||
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
|
||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||
// OpenRouter Params
|
||||
Usage json.RawMessage `json:"usage,omitempty"`
|
||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||
// Ali Qwen Params
|
||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||
|
||||
@@ -26,7 +26,7 @@ type OpenAITextResponse struct {
|
||||
Id string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Created any `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Usage `json:"usage"`
|
||||
@@ -178,6 +178,8 @@ type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
|
||||
// OpenRouter Params
|
||||
Cost float64 `json:"cost,omitempty"`
|
||||
}
|
||||
|
||||
type InputTokenDetails struct {
|
||||
|
||||
38
dto/ratio_sync.go
Normal file
38
dto/ratio_sync.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package dto
|
||||
|
||||
// UpstreamDTO describes one upstream to fetch ratio data from.
type UpstreamDTO struct {
	ID       int    `json:"id,omitempty"`
	Name     string `json:"name" binding:"required"`
	BaseURL  string `json:"base_url" binding:"required"`
	Endpoint string `json:"endpoint"`
}

// UpstreamRequest is the request body of a ratio sync: the upstreams to query,
// given as channel IDs and/or explicit upstream descriptions, plus a timeout.
type UpstreamRequest struct {
	ChannelIDs []int64       `json:"channel_ids"`
	Upstreams  []UpstreamDTO `json:"upstreams"`
	Timeout    int           `json:"timeout"`
}

// TestResult is the connectivity test result for one upstream.
type TestResult struct {
	Name   string `json:"name"`
	Status string `json:"status"`
	Error  string `json:"error,omitempty"`
}

// DifferenceItem is one difference entry.
//
// Current is the local value and may be nil. Upstreams holds the upstream
// value per channel: a concrete value, the string "same", or nil. Confidence
// holds a per-channel flag alongside the values.
type DifferenceItem struct {
	Current    any             `json:"current"`
	Upstreams  map[string]any  `json:"upstreams"`
	Confidence map[string]bool `json:"confidence"`
}

// SyncableChannel is the summary of a channel that ratio data can be synced
// from.
type SyncableChannel struct {
	ID      int    `json:"id"`
	Name    string `json:"name"`
	BaseURL string `json:"base_url"`
	Status  int    `json:"status"`
}
|
||||
47
dto/video.go
Normal file
47
dto/video.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package dto
|
||||
|
||||
// VideoRequest is the request body for submitting a video generation task.
type VideoRequest struct {
	// Model is the model/style ID.
	Model string `json:"model,omitempty" example:"kling-v1"`
	// Prompt is the text prompt.
	Prompt string `json:"prompt,omitempty" example:"宇航员站起身走了"`
	// Image is the image input (URL or Base64).
	Image string `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"`
	// Duration is the video duration in seconds.
	Duration float64 `json:"duration" example:"5.0"`
	// Width is the video width.
	Width int `json:"width" example:"512"`
	// Height is the video height.
	Height int `json:"height" example:"512"`
	// Fps is the video frame rate.
	Fps int `json:"fps,omitempty" example:"30"`
	// Seed is the random seed.
	Seed int `json:"seed,omitempty" example:"20231234"`
	// N is the number of videos to generate.
	N int `json:"n,omitempty" example:"1"`
	// ResponseFormat is the response format.
	ResponseFormat string `json:"response_format,omitempty" example:"url"`
	// User is the user identifier.
	User string `json:"user,omitempty" example:"user-1234"`
	// Metadata carries vendor-specific/custom params (e.g. negative_prompt,
	// style, quality_level, etc.).
	Metadata map[string]any `json:"metadata,omitempty"`
}

// VideoResponse is the response returned after a video generation task has
// been submitted.
type VideoResponse struct {
	TaskId string `json:"task_id"`
	Status string `json:"status"`
}

// VideoTaskResponse is the response to a video generation task status query.
type VideoTaskResponse struct {
	// TaskId is the task ID.
	TaskId string `json:"task_id" example:"abcd1234efgh"`
	// Status is the task status.
	Status string `json:"status" example:"succeeded"`
	// Url is the video resource URL (on success).
	Url string `json:"url,omitempty"`
	// Format is the video format.
	Format string `json:"format,omitempty" example:"mp4"`
	// Metadata is the result metadata.
	Metadata *VideoTaskMetadata `json:"metadata,omitempty"`
	// Error is the error information (on failure).
	Error *VideoTaskError `json:"error,omitempty"`
}

// VideoTaskMetadata is the metadata of a finished video task.
type VideoTaskMetadata struct {
	// Duration is the actual duration of the generated video.
	Duration float64 `json:"duration" example:"5.0"`
	// Fps is the actual frame rate.
	Fps int `json:"fps" example:"30"`
	// Width is the actual width.
	Width int `json:"width" example:"512"`
	// Height is the actual height.
	Height int `json:"height" example:"512"`
	// Seed is the random seed that was used.
	Seed int `json:"seed" example:"20231234"`
}

// VideoTaskError is the error information of a video task.
type VideoTaskError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}
|
||||
4
main.go
4
main.go
@@ -12,7 +12,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/router"
|
||||
"one-api/service"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
@@ -74,7 +74,7 @@ func main() {
|
||||
}
|
||||
|
||||
// Initialize model settings
|
||||
operation_setting.InitRatioSettings()
|
||||
ratio_setting.InitRatioSettings()
|
||||
// Initialize constants
|
||||
constant.InitEnv()
|
||||
// Initialize options
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -48,13 +49,15 @@ func Distribute() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// check group in common.GroupRatio
|
||||
if !setting.ContainsGroupRatio(tokenGroup) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
||||
if tokenGroup != "auto" {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
}
|
||||
}
|
||||
userGroup = tokenGroup
|
||||
}
|
||||
c.Set("group", userGroup)
|
||||
c.Set(constant.ContextKeyUsingGroup, userGroup)
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
@@ -95,9 +98,14 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
|
||||
if shouldSelectChannel {
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
|
||||
var selectGroup string
|
||||
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
showGroup := userGroup
|
||||
if userGroup == "auto" {
|
||||
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
||||
}
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
@@ -162,6 +170,15 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||
relayMode := relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeKlingFetchByID {
|
||||
shouldSelectChannel = false
|
||||
} else {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
c.Set("platform", string(constant.TaskPlatformKling))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
||||
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
||||
relayMode := relayconstant.RelayModeGemini
|
||||
|
||||
@@ -5,10 +5,13 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/setting"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
@@ -75,7 +78,43 @@ func SyncChannelCache(frequency int) {
|
||||
}
|
||||
}
|
||||
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
|
||||
var channel *Channel
|
||||
var err error
|
||||
selectGroup := group
|
||||
if group == "auto" {
|
||||
if len(setting.AutoGroups) == 0 {
|
||||
return nil, selectGroup, errors.New("auto groups is not enabled")
|
||||
}
|
||||
for _, autoGroup := range setting.AutoGroups {
|
||||
if common.DebugEnabled {
|
||||
println("autoGroup:", autoGroup)
|
||||
}
|
||||
channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
|
||||
if channel == nil {
|
||||
continue
|
||||
} else {
|
||||
c.Set("auto_group", autoGroup)
|
||||
selectGroup = autoGroup
|
||||
if common.DebugEnabled {
|
||||
println("selectGroup:", selectGroup)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
channel, err = getRandomSatisfiedChannel(group, model, retry)
|
||||
if err != nil {
|
||||
return nil, group, err
|
||||
}
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, group, errors.New("channel not found")
|
||||
}
|
||||
return channel, selectGroup, nil
|
||||
}
|
||||
|
||||
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
||||
model = "gpt-4-gizmo-*"
|
||||
}
|
||||
|
||||
@@ -597,3 +597,39 @@ func CountAllTags() (int64, error) {
|
||||
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// Get channels of specified type with pagination
|
||||
func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||||
return channels, err
|
||||
}
|
||||
|
||||
// Count channels of specific type
|
||||
func CountChannelsByType(channelType int) (int64, error) {
|
||||
var count int64
|
||||
err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Return map[type]count for all channels
|
||||
func CountChannelsGroupByType() (map[int64]int64, error) {
|
||||
type result struct {
|
||||
Type int64 `gorm:"column:type"`
|
||||
Count int64 `gorm:"column:count"`
|
||||
}
|
||||
var results []result
|
||||
err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts := make(map[int64]int64)
|
||||
for _, r := range results {
|
||||
counts[r.Type] = r.Count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
@@ -46,6 +46,15 @@ func initCol() {
|
||||
logGroupCol = commonGroupCol
|
||||
logKeyCol = commonKeyCol
|
||||
}
|
||||
} else {
|
||||
// LOG_SQL_DSN 为空时,日志数据库与主数据库相同
|
||||
if common.UsingPostgreSQL {
|
||||
logGroupCol = `"group"`
|
||||
logKeyCol = `"key"`
|
||||
} else {
|
||||
logGroupCol = commonGroupCol
|
||||
logKeyCol = commonKeyCol
|
||||
}
|
||||
}
|
||||
// log sql type and database type
|
||||
common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"one-api/setting"
|
||||
"one-api/setting/config"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -76,6 +77,9 @@ func InitOptionMap() {
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
|
||||
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
@@ -94,13 +98,13 @@ func InitOptionMap() {
|
||||
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
|
||||
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
|
||||
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
|
||||
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
|
||||
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
|
||||
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
||||
common.OptionMap["GroupGroupRatio"] = setting.GroupGroupRatio2JSONString()
|
||||
common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
|
||||
common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
|
||||
common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
|
||||
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
|
||||
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
|
||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||
//common.OptionMap["ChatLink"] = common.ChatLink
|
||||
//common.OptionMap["ChatLink2"] = common.ChatLink2
|
||||
@@ -123,6 +127,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
||||
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
|
||||
common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
|
||||
|
||||
// 自动添加所有注册的模型配置
|
||||
modelConfigs := config.GlobalConfig.ExportAllConfigs()
|
||||
@@ -192,7 +197,7 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.ImageDownloadPermission = intValue
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
|
||||
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
|
||||
boolValue := value == "true"
|
||||
switch key {
|
||||
case "PasswordRegisterEnabled":
|
||||
@@ -261,6 +266,10 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.SMTPSSLEnabled = boolValue
|
||||
case "WorkerAllowHttpImageRequestEnabled":
|
||||
setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
||||
case "DefaultUseAutoGroup":
|
||||
setting.DefaultUseAutoGroup = boolValue
|
||||
case "ExposeRatioEnabled":
|
||||
ratio_setting.SetExposeRatioEnabled(boolValue)
|
||||
}
|
||||
}
|
||||
switch key {
|
||||
@@ -287,6 +296,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.PayAddress = value
|
||||
case "Chats":
|
||||
err = setting.UpdateChatsByJsonString(value)
|
||||
case "AutoGroups":
|
||||
err = setting.UpdateAutoGroupsByJsonString(value)
|
||||
case "CustomCallbackAddress":
|
||||
setting.CustomCallbackAddress = value
|
||||
case "EpayId":
|
||||
@@ -352,19 +363,19 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "DataExportDefaultTime":
|
||||
common.DataExportDefaultTime = value
|
||||
case "ModelRatio":
|
||||
err = operation_setting.UpdateModelRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateModelRatioByJSONString(value)
|
||||
case "GroupRatio":
|
||||
err = setting.UpdateGroupRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateGroupRatioByJSONString(value)
|
||||
case "GroupGroupRatio":
|
||||
err = setting.UpdateGroupGroupRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
|
||||
case "UserUsableGroups":
|
||||
err = setting.UpdateUserUsableGroupsByJSONString(value)
|
||||
case "CompletionRatio":
|
||||
err = operation_setting.UpdateCompletionRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateCompletionRatioByJSONString(value)
|
||||
case "ModelPrice":
|
||||
err = operation_setting.UpdateModelPriceByJSONString(value)
|
||||
err = ratio_setting.UpdateModelPriceByJSONString(value)
|
||||
case "CacheRatio":
|
||||
err = operation_setting.UpdateCacheRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateCacheRatioByJSONString(value)
|
||||
case "TopUpLink":
|
||||
common.TopUpLink = value
|
||||
//case "ChatLink":
|
||||
@@ -381,6 +392,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
operation_setting.AutomaticDisableKeywordsFromString(value)
|
||||
case "StreamCacheQueueLength":
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
case "PayMethods":
|
||||
err = setting.UpdatePayMethodsByJsonString(value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -65,14 +65,14 @@ func updatePricing() {
|
||||
ModelName: model,
|
||||
EnableGroup: groups,
|
||||
}
|
||||
modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
|
||||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||||
if findPrice {
|
||||
pricing.ModelPrice = modelPrice
|
||||
pricing.QuotaType = 1
|
||||
} else {
|
||||
modelRatio, _ := operation_setting.GetModelRatio(model)
|
||||
modelRatio, _ := ratio_setting.GetModelRatio(model)
|
||||
pricing.ModelRatio = modelRatio
|
||||
pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
|
||||
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
||||
pricing.QuotaType = 0
|
||||
}
|
||||
pricingMap = append(pricingMap, pricing)
|
||||
|
||||
@@ -327,3 +327,37 @@ func CountUserTokens(userId int) (int64, error) {
|
||||
err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// BatchDeleteTokens 删除指定用户的一组令牌,返回成功删除数量
|
||||
func BatchDeleteTokens(ids []int, userId int) (int, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, errors.New("ids 不能为空!")
|
||||
}
|
||||
|
||||
tx := DB.Begin()
|
||||
|
||||
var tokens []Token
|
||||
if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Find(&tokens).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Delete(&Token{}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if common.RedisEnabled {
|
||||
gopool.Go(func() {
|
||||
for _, t := range tokens {
|
||||
_ = cacheDeleteToken(t.Key)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return len(tokens), nil
|
||||
}
|
||||
|
||||
@@ -2,11 +2,12 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -48,6 +49,22 @@ func addNewRecord(type_ int, id int, value int) {
|
||||
}
|
||||
|
||||
func batchUpdate() {
|
||||
// check if there's any data to update
|
||||
hasData := false
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
if len(batchUpdateStores[i]) > 0 {
|
||||
hasData = true
|
||||
batchUpdateLocks[i].Unlock()
|
||||
break
|
||||
}
|
||||
batchUpdateLocks[i].Unlock()
|
||||
}
|
||||
|
||||
if !hasData {
|
||||
return
|
||||
}
|
||||
|
||||
common.SysLog("batch update started")
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
|
||||
@@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
}
|
||||
|
||||
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
|
||||
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
||||
|
||||
if err != nil {
|
||||
@@ -66,10 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
promptTokens := 0
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||
promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
|
||||
}
|
||||
promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
||||
preConsumedTokens = promptTokens
|
||||
relayInfo.PromptTokens = promptTokens
|
||||
}
|
||||
@@ -89,13 +86,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
}
|
||||
}()
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
audioRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
@@ -44,4 +44,6 @@ type TaskAdaptor interface {
|
||||
|
||||
// FetchTask
|
||||
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
||||
|
||||
ParseResultUrl(resp map[string]any) (string, error)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/openrouter"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
@@ -122,6 +123,21 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||
}
|
||||
|
||||
if textRequest.Reasoning != nil {
|
||||
var reasoning openrouter.RequestReasoning
|
||||
if err := common.DecodeJson(textRequest.Reasoning, &reasoning); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
budgetTokens := reasoning.MaxTokens
|
||||
if budgetTokens > 0 {
|
||||
claudeRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: &budgetTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if textRequest.Stop != nil {
|
||||
// stop maybe string/array string, convert to array string
|
||||
switch textRequest.Stop.(type) {
|
||||
@@ -454,6 +470,7 @@ type ClaudeResponseInfo struct {
|
||||
Model string
|
||||
ResponseText strings.Builder
|
||||
Usage *dto.Usage
|
||||
Done bool
|
||||
}
|
||||
|
||||
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
|
||||
@@ -461,20 +478,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
claudeInfo.ResponseId = claudeResponse.Message.Id
|
||||
claudeInfo.Model = claudeResponse.Message.Model
|
||||
|
||||
// message_start, 获取usage
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
if claudeResponse.Delta.Text != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
|
||||
}
|
||||
if claudeResponse.Delta.Thinking != "" {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
// 最终的usage获取
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||
|
||||
// 判断是否完整
|
||||
claudeInfo.Done = true
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
} else {
|
||||
return false
|
||||
@@ -506,25 +535,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
}
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
info.UpstreamModelName = claudeResponse.Message.Model
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
helper.ClaudeChunkData(c, claudeResponse, data)
|
||||
@@ -544,29 +563,25 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
|
||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
|
||||
if common.DebugEnabled {
|
||||
common.SysError("claude response usage is not complete, maybe upstream error")
|
||||
}
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 说明流模式建立失败,可能为官方出错
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//usage.PromptTokens = info.PromptTokens
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
//
|
||||
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
@@ -619,10 +634,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
|
||||
}
|
||||
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
claudeInfo.Usage.PromptTokens = info.PromptTokens
|
||||
claudeInfo.Usage.CompletionTokens = completionTokens
|
||||
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
|
||||
|
||||
@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
@@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
||||
for _, choice := range response.Choices {
|
||||
responseText += choice.Message.StringContent()
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
response.Usage = *usage
|
||||
response.Id = helper.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
@@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
return nil, usage
|
||||
|
||||
@@ -162,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
}
|
||||
})
|
||||
if usage.PromptTokens == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
|
||||
var currentEvent string
|
||||
var currentData string
|
||||
var usage dto.Usage
|
||||
var usage = &dto.Usage{}
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
@@ -114,7 +114,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
if line == "" {
|
||||
if currentEvent != "" && currentData != "" {
|
||||
// handle last event
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
|
||||
currentEvent = ""
|
||||
currentData = ""
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
|
||||
// Last event
|
||||
if currentEvent != "" && currentData != "" {
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
@@ -143,12 +143,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
helper.Done(c)
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||
}
|
||||
|
||||
return nil, &usage
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
||||
|
||||
@@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
return true
|
||||
})
|
||||
helper.Done(c)
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
common.SysError("close_response_body_failed: " + err.Error())
|
||||
}
|
||||
if usage.TotalTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
usage.CompletionTokens += nodeToken
|
||||
return nil, usage
|
||||
|
||||
@@ -73,12 +73,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.OriginModelName, "-thinking-") {
|
||||
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||
parts := strings.Split(info.UpstreamModelName, "-thinking-")
|
||||
info.UpstreamModelName = parts[0]
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 旧的适配
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,6 +140,7 @@ type GeminiChatGenerationConfig struct {
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -35,23 +36,10 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
||||
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 检查是否有候选响应
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return nil, &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: "No candidates returned",
|
||||
Type: "server_error",
|
||||
Param: "",
|
||||
Code: 500,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
// 计算使用量(基于 UsageMetadata)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
|
||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
|
||||
@@ -88,6 +76,8 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
err := common.DecodeJsonStr(data, &geminiResponse)
|
||||
@@ -102,13 +92,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||
imageCount++
|
||||
}
|
||||
if part.Text != "" {
|
||||
responseText.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新使用量统计
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
@@ -121,7 +114,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
||||
}
|
||||
|
||||
// 直接发送 GeminiChatResponse 响应
|
||||
err = helper.ObjectData(c, geminiResponse)
|
||||
err = helper.StringData(c, data)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
@@ -135,8 +128,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
||||
}
|
||||
}
|
||||
|
||||
// 计算最终使用量
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
|
||||
if usage.CompletionTokens == 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 空补全,不需要使用量
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
}
|
||||
|
||||
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
||||
//helper.Done(c)
|
||||
|
||||
@@ -39,11 +39,100 @@ var geminiSupportedMimeTypes = map[string]bool{
|
||||
|
||||
// Allowed thinking-budget ranges for Gemini models, as applied by
// clampThinkingBudget.
const (
	// Bounds used for newer gemini-2.5-pro models.
	pro25MinBudget = 128
	pro25MaxBudget = 32768
	// Upper bound used for every other model (their lower bound is 0).
	flash25MaxBudget = 24576
	// Bounds used for gemini-2.5-flash-lite models.
	flash25LiteMinBudget = 512
	flash25LiteMaxBudget = 24576
)

// clampThinkingBudget limits budget to the range allowed for modelName and
// returns the clamped value.
//
// Models are matched by name prefix:
//   - "gemini-2.5-flash-lite": [flash25LiteMinBudget, flash25LiteMaxBudget]
//   - "gemini-2.5-pro", except the "gemini-2.5-pro-preview-05-06" and
//     "gemini-2.5-pro-preview-03-25" previews: [pro25MinBudget, pro25MaxBudget]
//   - anything else: [0, flash25MaxBudget]
func clampThinkingBudget(modelName string, budget int) int {
	// Default range; overridden below for the model families with their own limits.
	lo, hi := 0, flash25MaxBudget

	switch {
	case strings.HasPrefix(modelName, "gemini-2.5-flash-lite"):
		lo, hi = flash25LiteMinBudget, flash25LiteMaxBudget
	case strings.HasPrefix(modelName, "gemini-2.5-pro") &&
		!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
		!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25"):
		lo, hi = pro25MinBudget, pro25MaxBudget
	}

	if budget < lo {
		return lo
	}
	if budget > hi {
		return hi
	}
	return budget
}
|
||||
|
||||
func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) {
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
modelName := info.UpstreamModelName
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
||||
|
||||
if strings.Contains(modelName, "-thinking-") {
|
||||
parts := strings.SplitN(modelName, "-thinking-", 2)
|
||||
if len(parts) == 2 && parts[1] != "" {
|
||||
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
||||
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(clampedBudget),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "-thinking") {
|
||||
unsupportedModels := []string{
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
"gemini-2.5-pro-preview-03-25",
|
||||
}
|
||||
isUnsupported := false
|
||||
for _, unsupportedModel := range unsupportedModels {
|
||||
if strings.HasPrefix(modelName, unsupportedModel) {
|
||||
isUnsupported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isUnsupported {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
} else {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
|
||||
clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
|
||||
geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "-nothinking") {
|
||||
if !isNew25Pro {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
|
||||
|
||||
@@ -64,100 +153,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
}
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.OriginModelName, "-thinking-") {
|
||||
parts := strings.SplitN(info.OriginModelName, "-thinking-", 2)
|
||||
if len(parts) == 2 && parts[1] != "" {
|
||||
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
||||
// 从模型名称成功解析预算
|
||||
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
|
||||
|
||||
if isNew25Pro {
|
||||
// 新的2.5pro模型:ThinkingBudget范围为128-32768
|
||||
if budgetTokens < pro25MinBudget {
|
||||
budgetTokens = pro25MinBudget
|
||||
} else if budgetTokens > pro25MaxBudget {
|
||||
budgetTokens = pro25MaxBudget
|
||||
}
|
||||
} else {
|
||||
// 其他模型:ThinkingBudget范围为0-24576
|
||||
if budgetTokens < 0 {
|
||||
budgetTokens = 0
|
||||
} else if budgetTokens > flash25MaxBudget {
|
||||
budgetTokens = flash25MaxBudget
|
||||
}
|
||||
}
|
||||
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(budgetTokens),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
}
|
||||
// 如果解析失败,则不设置ThinkingConfig,静默处理
|
||||
}
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 保留旧逻辑以兼容
|
||||
// 硬编码不支持 ThinkingBudget 的旧模型
|
||||
unsupportedModels := []string{
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
"gemini-2.5-pro-preview-03-25",
|
||||
}
|
||||
|
||||
isUnsupported := false
|
||||
for _, unsupportedModel := range unsupportedModels {
|
||||
if strings.HasPrefix(info.OriginModelName, unsupportedModel) {
|
||||
isUnsupported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isUnsupported {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
} else {
|
||||
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
|
||||
|
||||
// 检查是否为新的2.5pro模型(支持ThinkingBudget但有特殊范围)
|
||||
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
|
||||
|
||||
if isNew25Pro {
|
||||
// 新的2.5pro模型:ThinkingBudget范围为128-32768
|
||||
if budgetTokens == 0 || budgetTokens < 128 {
|
||||
budgetTokens = 128
|
||||
} else if budgetTokens > 32768 {
|
||||
budgetTokens = 32768
|
||||
}
|
||||
} else {
|
||||
// 其他模型:ThinkingBudget范围为0-24576
|
||||
if budgetTokens == 0 || budgetTokens > 24576 {
|
||||
budgetTokens = 24576
|
||||
}
|
||||
}
|
||||
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(int(budgetTokens)),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
// 检查是否为新的2.5pro模型(不支持-nothinking,因为最低值只能为128)
|
||||
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
|
||||
|
||||
if !isNew25Pro {
|
||||
// 只有非新2.5pro模型才支持-nothinking
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ThinkingAdaptor(&geminiRequest, info)
|
||||
|
||||
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||
for _, category := range SafetySettingList {
|
||||
@@ -324,7 +320,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
|
||||
// 校验 MimeType 是否在 Gemini 支持的白名单中
|
||||
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
|
||||
return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList())
|
||||
url := part.GetImageMedia().Url
|
||||
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
@@ -382,7 +379,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if content.Role == "assistant" {
|
||||
content.Role = "model"
|
||||
}
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
if len(content.Parts) > 0 {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(system_content) > 0 {
|
||||
|
||||
@@ -159,6 +159,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
|
||||
request.StreamOptions = nil
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||
if len(request.Usage) == 0 {
|
||||
request.Usage = json.RawMessage(`{"include":true}`)
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(request.Model, "o") {
|
||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"math"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
@@ -16,6 +15,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
@@ -181,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
} else {
|
||||
if info.ChannelType == common.ChannelTypeDeepSeek {
|
||||
@@ -216,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
forceFormat := false
|
||||
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
|
||||
forceFormat = forceFmt
|
||||
@@ -225,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||
completionTokens := 0
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||
completionTokens += ctkm
|
||||
}
|
||||
simpleResponse.Usage = dto.Usage{
|
||||
@@ -276,9 +276,9 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
// the status code has been judged before, if there is a body reading failure,
|
||||
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
|
||||
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
||||
// if the upstream returns a specific status code, once the upstream has already written the header,
|
||||
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
||||
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
||||
// if the upstream returns a specific status code, once the upstream has already written the header,
|
||||
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
||||
// and can be terminated directly.
|
||||
defer resp.Body.Close()
|
||||
usage := &dto.Usage{}
|
||||
@@ -346,12 +346,12 @@ func countAudioTokens(c *gin.Context) (int, error) {
|
||||
if err = c.ShouldBind(&reqBody); err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
|
||||
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
|
||||
reqFp, err := reqBody.File.Open()
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
defer reqFp.Close()
|
||||
defer reqFp.Close()
|
||||
|
||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||
if err != nil {
|
||||
|
||||
@@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
|
||||
tempStr := responseTextBuilder.String()
|
||||
if len(tempStr) > 0 {
|
||||
// 非正常结束,使用输出文本的 token 数量
|
||||
completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
|
||||
completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
|
||||
usage.CompletionTokens = completionTokens
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = palmStreamHandler(c, resp)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
|
||||
@@ -155,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||
completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
||||
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
|
||||
312
relay/channel/task/kling/adaptor.go
Normal file
312
relay/channel/task/kling/adaptor.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package kling
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
// SubmitReq is the client-facing JSON body accepted for a Kling video
// generation task. It is decoded in ValidateRequestAndSetAction and later
// translated into the upstream format by convertToRequestPayload.
type SubmitReq struct {
	// Prompt is the text prompt; a blank prompt is rejected during validation.
	Prompt string `json:"prompt"`
	// Model is the requested model name; an empty value falls back to "kling-v1".
	Model string `json:"model,omitempty"`
	// Mode is the generation mode; an empty value falls back to "std".
	Mode string `json:"mode,omitempty"`
	// Image is forwarded to the upstream payload unchanged.
	Image string `json:"image,omitempty"`
	// Size is a "WIDTHxHEIGHT" string that is mapped to an aspect ratio.
	Size string `json:"size,omitempty"`
	// Duration is the clip length; zero falls back to 5.
	Duration int `json:"duration,omitempty"`
	// Metadata carries arbitrary extra fields; it is not forwarded upstream
	// by convertToRequestPayload.
	Metadata map[string]any `json:"metadata,omitempty"`
}
|
||||
|
||||
// requestPayload is the JSON body sent upstream. It is built from a SubmitReq
// by convertToRequestPayload, which fills both Model and ModelName with the
// same value and renders Duration as a decimal string.
type requestPayload struct {
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Duration string `json:"duration,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
ModelName string `json:"model_name,omitempty"`
|
||||
CfgScale float64 `json:"cfg_scale,omitempty"`
|
||||
}
|
||||
|
||||
// responsePayload is the shape DoResponse first tries to decode the upstream
// submit response into. A Code of 0 is treated as success, in which case
// Data.TaskID is returned as the task identifier.
type responsePayload struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data struct {
|
||||
TaskID string `json:"task_id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
|
||||
// TaskAdaptor holds the per-request state for the Kling task channel: the
// channel type, the upstream base URL, and the access/secret key pair that
// Init extracts from the configured API key.
type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
accessKey string
|
||||
secretKey string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.BaseUrl
|
||||
|
||||
// apiKey format: "access_key|secret_key"
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
if len(keyParts) == 2 {
|
||||
a.accessKey = strings.TrimSpace(keyParts[0])
|
||||
a.secretKey = strings.TrimSpace(keyParts[1])
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := "generate"
|
||||
info.Action = action
|
||||
|
||||
var req SubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store into context for later usage
|
||||
c.Set("kling_request", req)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/v1/videos/image2video", a.baseURL), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
token, err := a.createJWTToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create JWT token: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Kling specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("kling_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(SubmitReq)
|
||||
|
||||
body := a.convertToRequestPayload(&req)
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt Kling response parse first.
|
||||
var kResp responsePayload
|
||||
if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskID})
|
||||
return kResp.Data.TaskID, responseBody, nil
|
||||
}
|
||||
|
||||
// Fallback generic task response.
|
||||
var generic dto.TaskResponse[string]
|
||||
if err := json.Unmarshal(responseBody, &generic); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !generic.IsSuccess() {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
|
||||
return generic.Data, responseBody, nil
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
url := fmt.Sprintf("%s/v1/videos/image2video/%s", baseUrl, taskID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token, err := a.createJWTTokenWithKey(key)
|
||||
if err != nil {
|
||||
token = key
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetChannelName() string {
|
||||
return "kling"
|
||||
}
|
||||
|
||||
// ============================
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) *requestPayload {
|
||||
r := &requestPayload{
|
||||
Prompt: req.Prompt,
|
||||
Image: req.Image,
|
||||
Mode: defaultString(req.Mode, "std"),
|
||||
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
|
||||
AspectRatio: a.getAspectRatio(req.Size),
|
||||
Model: req.Model,
|
||||
ModelName: req.Model,
|
||||
CfgScale: 0.5,
|
||||
}
|
||||
if r.Model == "" {
|
||||
r.Model = "kling-v1"
|
||||
r.ModelName = "kling-v1"
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) getAspectRatio(size string) string {
|
||||
switch size {
|
||||
case "1024x1024", "512x512":
|
||||
return "1:1"
|
||||
case "1280x720", "1920x1080":
|
||||
return "16:9"
|
||||
case "720x1280", "1080x1920":
|
||||
return "9:16"
|
||||
default:
|
||||
return "1:1"
|
||||
}
|
||||
}
|
||||
|
||||
// defaultString returns s unless it is empty or whitespace-only, in which
// case def is returned. A non-blank s is returned untrimmed.
func defaultString(s, def string) string {
	if strings.TrimSpace(s) != "" {
		return s
	}
	return def
}
|
||||
|
||||
// defaultInt returns v unless it is zero, in which case def is returned.
// Negative values are passed through unchanged.
func defaultInt(v int, def int) int {
	if v != 0 {
		return v
	}
	return def
}
|
||||
|
||||
// ============================
|
||||
// JWT helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) createJWTToken() (string, error) {
|
||||
return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
||||
parts := strings.Split(apiKey, "|")
|
||||
if len(parts) != 2 {
|
||||
return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
|
||||
}
|
||||
return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
|
||||
if accessKey == "" || secretKey == "" {
|
||||
return "", fmt.Errorf("access key and secret key are required")
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": accessKey,
|
||||
"exp": now + 1800, // 30 minutes
|
||||
"nbf": now - 5,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
token.Header["typ"] = "JWT"
|
||||
return token.SignedString([]byte(secretKey))
|
||||
}
|
||||
|
||||
// ParseResultUrl 提取视频任务结果的 url
|
||||
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
|
||||
data, ok := resp["data"].(map[string]any)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("data field not found or invalid")
|
||||
}
|
||||
taskResult, ok := data["task_result"].(map[string]any)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("task_result field not found or invalid")
|
||||
}
|
||||
videos, ok := taskResult["videos"].([]interface{})
|
||||
if !ok || len(videos) == 0 {
|
||||
return "", fmt.Errorf("videos field not found or empty")
|
||||
}
|
||||
video, ok := videos[0].(map[string]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("video item invalid")
|
||||
}
|
||||
url, ok := video["url"].(string)
|
||||
if !ok || url == "" {
|
||||
return "", fmt.Errorf("url field not found or invalid")
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
@@ -22,6 +22,10 @@ type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
}
|
||||
|
||||
// ParseResultUrl is a placeholder for this adaptor: it ignores resp and
// always returns an empty URL with a nil error.
func (a *TaskAdaptor) ParseResultUrl(resp map[string]any) (string, error) {
|
||||
return "", nil // todo implement this method if needed
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
}
|
||||
|
||||
@@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = tencentStreamHandler(c, resp)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = tencentHandler(c, resp)
|
||||
}
|
||||
|
||||
@@ -83,10 +83,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
suffix := ""
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// suffix -thinking and -nothinking
|
||||
if strings.HasSuffix(info.OriginModelName, "-thinking") {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||
parts := strings.Split(info.UpstreamModelName, "-thinking-")
|
||||
info.UpstreamModelName = parts[0]
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
}
|
||||
}
|
||||
@@ -123,14 +126,23 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
model = v
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
package volcengine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -30,8 +34,146 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesEdits:
|
||||
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
|
||||
writer.WriteField("model", request.Model)
|
||||
// 获取所有表单字段
|
||||
formData := c.Request.PostForm
|
||||
// 遍历表单字段并打印输出
|
||||
for key, values := range formData {
|
||||
if key == "model" {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
writer.WriteField(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the multipart form to handle both single image and multiple images
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
|
||||
return nil, errors.New("failed to parse multipart form")
|
||||
}
|
||||
|
||||
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
|
||||
// Check if "image" field exists in any form, including array notation
|
||||
var imageFiles []*multipart.FileHeader
|
||||
var exists bool
|
||||
|
||||
// First check for standard "image" field
|
||||
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
|
||||
// If not found, check for "image[]" field
|
||||
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
|
||||
// If still not found, iterate through all fields to find any that start with "image["
|
||||
foundArrayImages := false
|
||||
for fieldName, files := range c.Request.MultipartForm.File {
|
||||
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
||||
foundArrayImages = true
|
||||
for _, file := range files {
|
||||
imageFiles = append(imageFiles, file)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no image fields found at all
|
||||
if !foundArrayImages && (len(imageFiles) == 0) {
|
||||
return nil, errors.New("image is required")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process all image files
|
||||
for i, fileHeader := range imageFiles {
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// If multiple images, use image[] as the field name
|
||||
fieldName := "image"
|
||||
if len(imageFiles) > 1 {
|
||||
fieldName = "image[]"
|
||||
}
|
||||
|
||||
// Determine MIME type based on file extension
|
||||
mimeType := detectImageMimeType(fileHeader.Filename)
|
||||
|
||||
// Create a form file with the appropriate content type
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
|
||||
h.Set("Content-Type", mimeType)
|
||||
|
||||
part, err := writer.CreatePart(h)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(part, file); err != nil {
|
||||
return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle mask file if present
|
||||
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
|
||||
maskFile, err := maskFiles[0].Open()
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to open mask file")
|
||||
}
|
||||
defer maskFile.Close()
|
||||
|
||||
// Determine MIME type for mask file
|
||||
mimeType := detectImageMimeType(maskFiles[0].Filename)
|
||||
|
||||
// Create a form file with the appropriate content type
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
|
||||
h.Set("Content-Type", mimeType)
|
||||
|
||||
maskPart, err := writer.CreatePart(h)
|
||||
if err != nil {
|
||||
return nil, errors.New("create form file failed for mask")
|
||||
}
|
||||
|
||||
if _, err := io.Copy(maskPart, maskFile); err != nil {
|
||||
return nil, errors.New("copy mask file failed")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("no multipart form data found")
|
||||
}
|
||||
|
||||
// 关闭 multipart 编写器以设置分界线
|
||||
writer.Close()
|
||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return bytes.NewReader(requestBody.Bytes()), nil
|
||||
|
||||
default:
|
||||
return request, nil
|
||||
}
|
||||
}
|
||||
|
||||
// detectImageMimeType guesses an image MIME type from the file extension of
// filename (case-insensitive). ".webp" maps to "image/webp", any extension
// beginning with ".jp" (such as ".jpg" or ".jpeg") maps to "image/jpeg", and
// everything else, including ".png" and a missing extension, maps to
// "image/png".
func detectImageMimeType(filename string) string {
	ext := strings.ToLower(filepath.Ext(filename))

	if ext == ".webp" {
		return "image/webp"
	}
	// Covers ".jpg" and ".jpeg" as well as other ".jp*" variants.
	if strings.HasPrefix(ext, ".jp") {
		return "image/jpeg"
	}
	// PNG is both the explicit ".png" answer and the fallback.
	return "image/png"
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -46,6 +188,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
|
||||
case constant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
|
||||
case constant.RelayModeImagesGenerations:
|
||||
return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
|
||||
default:
|
||||
}
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
@@ -91,6 +235,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
}
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||
err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
})
|
||||
|
||||
if !containStreamUsage {
|
||||
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
|
||||
@@ -46,13 +46,11 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
||||
relayInfo.IsStream = true
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
textRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
@@ -126,7 +124,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return service.ClaudeErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
|
||||
@@ -34,9 +34,14 @@ type ClaudeConvertInfo struct {
|
||||
}
|
||||
|
||||
const (
|
||||
RelayFormatOpenAI = "openai"
|
||||
RelayFormatClaude = "claude"
|
||||
RelayFormatGemini = "gemini"
|
||||
RelayFormatOpenAI = "openai"
|
||||
RelayFormatClaude = "claude"
|
||||
RelayFormatGemini = "gemini"
|
||||
RelayFormatOpenAIResponses = "openai_responses"
|
||||
RelayFormatOpenAIAudio = "openai_audio"
|
||||
RelayFormatOpenAIImage = "openai_image"
|
||||
RelayFormatRerank = "rerank"
|
||||
RelayFormatEmbedding = "embedding"
|
||||
)
|
||||
|
||||
type RerankerInfo struct {
|
||||
@@ -60,8 +65,8 @@ type RelayInfo struct {
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
Group string
|
||||
UserGroup string
|
||||
UsingGroup string // 使用的分组
|
||||
UserGroup string // 用户所在分组
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
@@ -143,6 +148,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
|
||||
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayMode = relayconstant.RelayModeRerank
|
||||
info.RelayFormat = RelayFormatRerank
|
||||
info.RerankerInfo = &RerankerInfo{
|
||||
Documents: req.Documents,
|
||||
ReturnDocuments: req.GetReturnDocuments(),
|
||||
@@ -150,9 +156,25 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatOpenAIAudio
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatEmbedding
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayMode = relayconstant.RelayModeResponses
|
||||
info.RelayFormat = RelayFormatOpenAIResponses
|
||||
|
||||
info.SupportStreamOptions = false
|
||||
|
||||
info.ResponsesUsageInfo = &ResponsesUsageInfo{
|
||||
BuiltInTools: make(map[string]*BuildInToolInfo),
|
||||
}
|
||||
@@ -175,6 +197,19 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatGemini
|
||||
info.ShouldIncludeUsage = false
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoImage(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatOpenAIImage
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
channelType := c.GetInt("channel_type")
|
||||
channelId := c.GetInt("channel_id")
|
||||
@@ -184,7 +219,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
tokenId := c.GetInt("token_id")
|
||||
tokenKey := c.GetString("token_key")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
tokenUnlimited := c.GetBool("token_unlimited_quota")
|
||||
startTime := c.GetTime(constant.ContextKeyRequestStartTime)
|
||||
// firstResponseTime = time.Now() - 1 second
|
||||
@@ -204,7 +238,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
TokenId: tokenId,
|
||||
TokenKey: tokenKey,
|
||||
UserId: userId,
|
||||
Group: group,
|
||||
UsingGroup: c.GetString(constant.ContextKeyUsingGroup),
|
||||
UserGroup: c.GetString(constant.ContextKeyUserGroup),
|
||||
TokenUnlimited: tokenUnlimited,
|
||||
StartTime: startTime,
|
||||
@@ -243,10 +277,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
if streamSupportedChannels[info.ChannelType] {
|
||||
info.SupportStreamOptions = true
|
||||
}
|
||||
// responses 模式不支持 StreamOptions
|
||||
if relayconstant.RelayModeResponses == info.RelayMode {
|
||||
info.SupportStreamOptions = false
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
|
||||
@@ -38,6 +38,9 @@ const (
|
||||
RelayModeSunoFetchByID
|
||||
RelayModeSunoSubmit
|
||||
|
||||
RelayModeKlingFetchByID
|
||||
RelayModeKlingSubmit
|
||||
|
||||
RelayModeRerank
|
||||
|
||||
RelayModeResponses
|
||||
@@ -133,3 +136,13 @@ func Path2RelaySuno(method, path string) int {
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
func Path2RelayKling(method, path string) int {
|
||||
relayMode := RelayModeUnknown
|
||||
if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
|
||||
relayMode = RelayModeKlingSubmit
|
||||
} else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
|
||||
relayMode = RelayModeKlingFetchByID
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
||||
token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
|
||||
token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
|
||||
return token
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed
|
||||
}
|
||||
|
||||
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
relayInfo := relaycommon.GenRelayInfoEmbedding(c)
|
||||
|
||||
var embeddingRequest *dto.EmbeddingRequest
|
||||
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
|
||||
@@ -47,13 +47,11 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
embeddingRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
||||
relayInfo.PromptTokens = promptToken
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/model_setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -59,7 +60,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
|
||||
return sensitiveWords, err
|
||||
}
|
||||
|
||||
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
|
||||
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
||||
// 计算输入 token 数量
|
||||
var inputTexts []string
|
||||
for _, content := range req.Contents {
|
||||
@@ -71,9 +72,36 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
|
||||
}
|
||||
|
||||
inputText := strings.Join(inputTexts, "\n")
|
||||
inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
|
||||
inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
|
||||
info.PromptTokens = inputTokens
|
||||
return inputTokens, err
|
||||
return inputTokens
|
||||
}
|
||||
|
||||
func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool {
|
||||
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
|
||||
return *req.GenerationConfig.ThinkingConfig.ThinkingBudget <= 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// trimModelThinking normalizes the thinking-related suffix of a model name.
//
//   - A trailing "-nothinking" or "-thinking" is removed.
//   - Otherwise, if the name contains "-thinking-" (e.g. a "-thinking-<number>"
//     form), everything from the first occurrence onward is replaced by
//     "-thinking".
//   - Any other name is returned unchanged.
func trimModelThinking(modelName string) string {
	// Plain suffixes are stripped outright; "-nothinking" is checked first,
	// matching the original precedence.
	for _, suffix := range [...]string{"-nothinking", "-thinking"} {
		if strings.HasSuffix(modelName, suffix) {
			return modelName[:len(modelName)-len(suffix)]
		}
	}

	// "-thinking-<number>" collapses to the bare "-thinking" variant.
	if idx := strings.Index(modelName, "-thinking-"); idx >= 0 {
		return modelName[:idx] + "-thinking"
	}
	return modelName
}
|
||||
|
||||
func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
@@ -83,7 +111,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
relayInfo := relaycommon.GenRelayInfoGemini(c)
|
||||
|
||||
// 检查 Gemini 流式模式
|
||||
checkGeminiStreamMode(c, relayInfo)
|
||||
@@ -97,7 +125,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
}
|
||||
|
||||
// model mapped 模型映射
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
|
||||
}
|
||||
@@ -106,13 +134,28 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
promptTokens := value.(int)
|
||||
relayInfo.SetPromptTokens(promptTokens)
|
||||
} else {
|
||||
promptTokens, err := getGeminiInputTokens(req, relayInfo)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
|
||||
}
|
||||
promptTokens := getGeminiInputTokens(req, relayInfo)
|
||||
c.Set("prompt_tokens", promptTokens)
|
||||
}
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
if isNoThinkingRequest(req) {
|
||||
// check is thinking
|
||||
if !strings.Contains(relayInfo.OriginModelName, "-nothinking") {
|
||||
// try to get no thinking model price
|
||||
noThinkingModelName := relayInfo.OriginModelName + "-nothinking"
|
||||
containPrice := helper.ContainPriceOrRatio(noThinkingModelName)
|
||||
if containPrice {
|
||||
relayInfo.OriginModelName = noThinkingModelName
|
||||
relayInfo.UpstreamModelName = noThinkingModelName
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.GenerationConfig.ThinkingConfig == nil {
|
||||
gemini.ThinkingAdaptor(req, relayInfo)
|
||||
}
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
@@ -155,14 +198,33 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println("Gemini request body: %s", string(requestBody))
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
common.LogError(c, "Do gemini request failed: "+err.Error())
|
||||
return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
|
||||
if openaiErr != nil {
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
@@ -4,12 +4,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
common2 "one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/common"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
|
||||
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
@@ -50,5 +52,41 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
|
||||
info.UpstreamModelName = currentModel
|
||||
}
|
||||
}
|
||||
if request != nil {
|
||||
switch info.RelayFormat {
|
||||
case common.RelayFormatGemini:
|
||||
// Gemini 模型映射
|
||||
case common.RelayFormatClaude:
|
||||
if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
|
||||
claudeRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatOpenAIResponses:
|
||||
if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
|
||||
openAIResponsesRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatOpenAIAudio:
|
||||
if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
|
||||
openAIAudioRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatOpenAIImage:
|
||||
if imageRequest, ok := request.(*dto.ImageRequest); ok {
|
||||
imageRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatRerank:
|
||||
if rerankRequest, ok := request.(*dto.RerankRequest); ok {
|
||||
rerankRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatEmbedding:
|
||||
if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
|
||||
embeddingRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
default:
|
||||
if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
|
||||
openAIRequest.Model = info.UpstreamModelName
|
||||
} else {
|
||||
common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,12 +5,17 @@ import (
|
||||
"one-api/common"
|
||||
constant2 "one-api/constant"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GroupRatioInfo describes the group ratio resolved for a request.
type GroupRatioInfo struct {
	// GroupRatio is the ratio applied when computing quota.
	// GroupSpecialRatio holds the ratio specific to the user-group/using-group
	// pair when one is found; HandleGroupRatio initializes it to -1.
	GroupRatio, GroupSpecialRatio float64
	// HasSpecialRatio is set when GroupSpecialRatio was found and used.
	HasSpecialRatio bool
}
|
||||
|
||||
type PriceData struct {
|
||||
ModelPrice float64
|
||||
ModelRatio float64
|
||||
@@ -18,23 +23,51 @@ type PriceData struct {
|
||||
CacheRatio float64
|
||||
CacheCreationRatio float64
|
||||
ImageRatio float64
|
||||
GroupRatio float64
|
||||
UserGroupRatio float64
|
||||
UsePrice bool
|
||||
ShouldPreConsumedQuota int
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
func (p PriceData) ToSetting() string {
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
|
||||
}
|
||||
|
||||
// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present
|
||||
func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
|
||||
groupRatioInfo := GroupRatioInfo{
|
||||
GroupRatio: 1.0, // default ratio
|
||||
GroupSpecialRatio: -1,
|
||||
}
|
||||
|
||||
// check auto group
|
||||
autoGroup, exists := ctx.Get("auto_group")
|
||||
if exists {
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("final group: %s", autoGroup))
|
||||
}
|
||||
relayInfo.UsingGroup = autoGroup.(string)
|
||||
}
|
||||
|
||||
// check user group special ratio
|
||||
userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
|
||||
if ok {
|
||||
// user group special ratio
|
||||
groupRatioInfo.GroupSpecialRatio = userGroupRatio
|
||||
groupRatioInfo.GroupRatio = userGroupRatio
|
||||
groupRatioInfo.HasSpecialRatio = true
|
||||
} else {
|
||||
// normal group ratio
|
||||
groupRatioInfo.GroupRatio = ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
|
||||
}
|
||||
|
||||
return groupRatioInfo
|
||||
}
|
||||
|
||||
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
|
||||
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
|
||||
groupRatio := setting.GetGroupRatio(info.Group)
|
||||
userGroupRatio, ok := setting.GetGroupGroupRatio(info.UserGroup, info.Group)
|
||||
if ok {
|
||||
groupRatio = userGroupRatio
|
||||
}
|
||||
modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false)
|
||||
|
||||
groupRatioInfo := HandleGroupRatio(c, info)
|
||||
|
||||
var preConsumedQuota int
|
||||
var modelRatio float64
|
||||
var completionRatio float64
|
||||
@@ -47,7 +80,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
preConsumedTokens = promptTokens + maxTokens
|
||||
}
|
||||
var success bool
|
||||
modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName)
|
||||
modelRatio, success = ratio_setting.GetModelRatio(info.OriginModelName)
|
||||
if !success {
|
||||
acceptUnsetRatio := false
|
||||
if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
|
||||
@@ -60,22 +93,21 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
|
||||
}
|
||||
}
|
||||
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
|
||||
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
|
||||
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
|
||||
imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
|
||||
ratio := modelRatio * groupRatio
|
||||
completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
|
||||
cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
|
||||
cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
|
||||
imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
|
||||
ratio := modelRatio * groupRatioInfo.GroupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||
}
|
||||
|
||||
priceData := PriceData{
|
||||
ModelPrice: modelPrice,
|
||||
ModelRatio: modelRatio,
|
||||
CompletionRatio: completionRatio,
|
||||
GroupRatio: groupRatio,
|
||||
UserGroupRatio: userGroupRatio,
|
||||
GroupRatioInfo: groupRatioInfo,
|
||||
UsePrice: usePrice,
|
||||
CacheRatio: cacheRatio,
|
||||
ImageRatio: imageRatio,
|
||||
@@ -90,12 +122,41 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
return priceData, nil
|
||||
}
|
||||
|
||||
type PerCallPriceData struct {
|
||||
ModelPrice float64
|
||||
Quota int
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
|
||||
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCallPriceData {
|
||||
groupRatioInfo := HandleGroupRatio(c, info)
|
||||
|
||||
modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if !success {
|
||||
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||
priceData := PerCallPriceData{
|
||||
ModelPrice: modelPrice,
|
||||
Quota: quota,
|
||||
GroupRatioInfo: groupRatioInfo,
|
||||
}
|
||||
return priceData
|
||||
}
|
||||
|
||||
func ContainPriceOrRatio(modelName string) bool {
|
||||
_, ok := operation_setting.GetModelPrice(modelName, false)
|
||||
_, ok := ratio_setting.GetModelPrice(modelName, false)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
_, ok = operation_setting.GetModelRatio(modelName)
|
||||
_, ok = ratio_setting.GetModelRatio(modelName)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
|
||||
"one-api/relay/constant"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -44,6 +46,11 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
|
||||
if info.ApiType == constant.APITypeVolcEngine {
|
||||
watermark := formData.Has("watermark")
|
||||
imageRequest.Watermark = &watermark
|
||||
}
|
||||
default:
|
||||
err := common.UnmarshalBodyReusable(c, imageRequest)
|
||||
if err != nil {
|
||||
@@ -102,7 +109,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
}
|
||||
|
||||
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
relayInfo := relaycommon.GenRelayInfoImage(c)
|
||||
|
||||
imageRequest, err := getAndValidImageRequest(c, relayInfo)
|
||||
if err != nil {
|
||||
@@ -110,13 +117,11 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
imageRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
@@ -162,7 +167,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
|
||||
// reset model price
|
||||
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
|
||||
quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
|
||||
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
|
||||
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -174,18 +174,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||||
}
|
||||
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
||||
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if !success {
|
||||
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
groupRatio := setting.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
|
||||
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
||||
|
||||
userQuota, err := model.GetUserQuota(userId, false)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
@@ -193,9 +184,8 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
Description: err.Error(),
|
||||
}
|
||||
}
|
||||
quota := int(ratio * common.QuotaPerUnit)
|
||||
|
||||
if userQuota-quota < 0 {
|
||||
if userQuota-priceData.Quota < 0 {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "quota_not_enough",
|
||||
@@ -210,26 +200,18 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
}
|
||||
defer func() {
|
||||
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
||||
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
//err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
||||
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
|
||||
other := service.GenerateMjOtherInfo(priceData)
|
||||
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
||||
priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
|
||||
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
|
||||
}
|
||||
}()
|
||||
midjResponse := &mjResp.Response
|
||||
@@ -250,7 +232,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
Progress: "0%",
|
||||
FailReason: "",
|
||||
ChannelId: c.GetInt("channel_id"),
|
||||
Quota: quota,
|
||||
Quota: priceData.Quota,
|
||||
}
|
||||
err = midjourneyTask.Insert()
|
||||
if err != nil {
|
||||
@@ -480,18 +462,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
|
||||
modelName := service.CoverActionToModelName(midjRequest.Action)
|
||||
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if !success {
|
||||
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
groupRatio := setting.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
|
||||
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
||||
|
||||
userQuota, err := model.GetUserQuota(userId, false)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
@@ -499,9 +472,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
Description: err.Error(),
|
||||
}
|
||||
}
|
||||
quota := int(ratio * common.QuotaPerUnit)
|
||||
|
||||
if consumeQuota && userQuota-quota < 0 {
|
||||
if consumeQuota && userQuota-priceData.Quota < 0 {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "quota_not_enough",
|
||||
@@ -516,22 +488,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
|
||||
defer func() {
|
||||
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
||||
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result)
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
||||
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
|
||||
other := service.GenerateMjOtherInfo(priceData)
|
||||
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
||||
priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
|
||||
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -559,7 +526,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
Progress: "0%",
|
||||
FailReason: "",
|
||||
ChannelId: c.GetInt("channel_id"),
|
||||
Quota: quota,
|
||||
Quota: priceData.Quota,
|
||||
}
|
||||
if midjResponse.Code == 3 {
|
||||
//无实例账号自动禁用渠道(No available account instance)
|
||||
|
||||
@@ -90,15 +90,16 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateTextRequest(c, relayInfo)
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
|
||||
}
|
||||
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
words, err := checkRequestSensitive(textRequest, relayInfo)
|
||||
if err != nil {
|
||||
@@ -107,13 +108,11 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
textRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
// 获取 promptTokens,如果上下文中已经存在,则直接使用
|
||||
var promptTokens int
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
@@ -252,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||
promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||
case relayconstant.RelayModeModerations:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||
default:
|
||||
err = errors.New("unknown relay mode")
|
||||
promptTokens = 0
|
||||
@@ -361,9 +360,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
cacheRatio := priceData.CacheRatio
|
||||
imageRatio := priceData.ImageRatio
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
userGroupRatio := priceData.UserGroupRatio
|
||||
|
||||
// Convert values to decimal for precise calculation
|
||||
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
||||
@@ -511,7 +509,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
if extraContent != "" {
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio)
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
if imageTokens != 0 {
|
||||
other["image"] = true
|
||||
other["image_ratio"] = imageRatio
|
||||
@@ -543,5 +541,5 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
other["audio_input_price"] = audioInputPrice
|
||||
}
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"one-api/relay/channel/palm"
|
||||
"one-api/relay/channel/perplexity"
|
||||
"one-api/relay/channel/siliconflow"
|
||||
"one-api/relay/channel/task/kling"
|
||||
"one-api/relay/channel/task/suno"
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/vertex"
|
||||
@@ -101,6 +102,8 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
|
||||
// return &aiproxy.Adaptor{}
|
||||
case commonconstant.TaskPlatformSuno:
|
||||
return &suno.TaskAdaptor{}
|
||||
case commonconstant.TaskPlatformKling:
|
||||
return &kling.TaskAdaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -15,8 +14,9 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
/*
|
||||
@@ -38,9 +38,12 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
|
||||
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
|
||||
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
|
||||
if platform == constant.TaskPlatformKling {
|
||||
modelName = relayInfo.OriginModelName
|
||||
}
|
||||
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
||||
if !success {
|
||||
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
|
||||
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
@@ -49,8 +52,14 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
|
||||
// 预扣
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
ratio := modelPrice * groupRatio
|
||||
groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
|
||||
var ratio float64
|
||||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
|
||||
if hasUserGroupRatio {
|
||||
ratio = modelPrice * userGroupRatio
|
||||
} else {
|
||||
ratio = modelPrice * groupRatio
|
||||
}
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
@@ -119,12 +128,19 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)
|
||||
gRatio := groupRatio
|
||||
if hasUserGroupRatio {
|
||||
gRatio = userGroupRatio
|
||||
}
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action)
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
if hasUserGroupRatio {
|
||||
other["user_group_ratio"] = userGroupRatio
|
||||
}
|
||||
model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
|
||||
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
|
||||
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.UsingGroup, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
@@ -137,10 +153,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
relayInfo.ConsumeQuota = true
|
||||
// insert task
|
||||
task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
|
||||
task := model.InitTask(platform, relayInfo)
|
||||
task.TaskID = taskID
|
||||
task.Quota = quota
|
||||
task.Data = taskData
|
||||
task.Action = relayInfo.Action
|
||||
err = task.Insert()
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|
||||
@@ -150,8 +167,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
|
||||
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
||||
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
||||
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|
||||
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
||||
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|
||||
relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder,
|
||||
}
|
||||
|
||||
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
|
||||
@@ -226,6 +244,27 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
|
||||
return
|
||||
}
|
||||
|
||||
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
|
||||
taskId := c.Param("id")
|
||||
userId := c.GetInt("id")
|
||||
|
||||
originTask, exist, err := model.GetByTaskId(userId, taskId)
|
||||
if err != nil {
|
||||
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !exist {
|
||||
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: TaskModel2Dto(originTask),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
|
||||
return &dto.TaskDto{
|
||||
TaskID: task.TaskID,
|
||||
|
||||
@@ -14,12 +14,10 @@ import (
|
||||
)
|
||||
|
||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||
token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
|
||||
token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
|
||||
for _, document := range rerankRequest.Documents {
|
||||
tkm, err := service.CountTokenInput(document, rerankRequest.Model)
|
||||
if err == nil {
|
||||
token += tkm
|
||||
}
|
||||
tkm := service.CountTokenInput(document, rerankRequest.Model)
|
||||
token += tkm
|
||||
}
|
||||
return token
|
||||
}
|
||||
@@ -42,13 +40,11 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
rerankRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
promptToken := getRerankPromptToken(*rerankRequest)
|
||||
relayInfo.PromptTokens = promptToken
|
||||
|
||||
@@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom
|
||||
return sensitiveWords, err
|
||||
}
|
||||
|
||||
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) {
|
||||
inputTokens, err := service.CountTokenInput(req.Input, req.Model)
|
||||
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
|
||||
inputTokens := service.CountTokenInput(req.Input, req.Model)
|
||||
info.PromptTokens = inputTokens
|
||||
return inputTokens, err
|
||||
return inputTokens
|
||||
}
|
||||
|
||||
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
@@ -63,19 +63,16 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
|
||||
}
|
||||
req.Model = relayInfo.UpstreamModelName
|
||||
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
promptTokens := value.(int)
|
||||
relayInfo.SetPromptTokens(promptTokens)
|
||||
} else {
|
||||
promptTokens, err := getInputTokens(req, relayInfo)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
|
||||
}
|
||||
promptTokens := getInputTokens(req, relayInfo)
|
||||
c.Set("prompt_tokens", promptTokens)
|
||||
}
|
||||
|
||||
@@ -6,12 +6,10 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
)
|
||||
|
||||
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
@@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
||||
//isModelMapped = true
|
||||
}
|
||||
}
|
||||
//relayInfo.UpstreamModelName = textRequest.Model
|
||||
modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
//err := service.SensitiveWordsCheck(textRequest)
|
||||
|
||||
//if constant.ShouldCheckPromptSensitive() {
|
||||
// err = checkRequestSensitive(textRequest, relayInfo)
|
||||
// if err != nil {
|
||||
// return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||
// }
|
||||
//}
|
||||
|
||||
//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
|
||||
//// count messages token error 计算promptTokens错误
|
||||
//if err != nil {
|
||||
// return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
//}
|
||||
//
|
||||
if !getModelPriceSuccess {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
|
||||
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
|
||||
//}
|
||||
modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
relayInfo.UsePrice = true
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
@@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
||||
return openaiErr
|
||||
}
|
||||
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
|
||||
userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||
userQuota, priceData, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
|
||||
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
|
||||
apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
|
||||
|
||||
userRoute := apiRouter.Group("/user")
|
||||
{
|
||||
@@ -83,6 +84,12 @@ func SetApiRouter(router *gin.Engine) {
|
||||
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
|
||||
optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
|
||||
}
|
||||
ratioSyncRoute := apiRouter.Group("/ratio_sync")
|
||||
ratioSyncRoute.Use(middleware.RootAuth())
|
||||
{
|
||||
ratioSyncRoute.GET("/channels", controller.GetSyncableChannels)
|
||||
ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios)
|
||||
}
|
||||
channelRoute := apiRouter.Group("/channel")
|
||||
channelRoute.Use(middleware.AdminAuth())
|
||||
{
|
||||
@@ -118,6 +125,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
tokenRoute.POST("/", controller.AddToken)
|
||||
tokenRoute.PUT("/", controller.UpdateToken)
|
||||
tokenRoute.DELETE("/:id", controller.DeleteToken)
|
||||
tokenRoute.POST("/batch", controller.DeleteTokenBatch)
|
||||
}
|
||||
redemptionRoute := apiRouter.Group("/redemption")
|
||||
redemptionRoute.Use(middleware.AdminAuth())
|
||||
|
||||
@@ -14,6 +14,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
||||
SetApiRouter(router)
|
||||
SetDashboardRouter(router)
|
||||
SetRelayRouter(router)
|
||||
SetVideoRouter(router)
|
||||
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
|
||||
if common.IsMasterNode && frontendBaseUrl != "" {
|
||||
frontendBaseUrl = ""
|
||||
|
||||
17
router/video-router.go
Normal file
17
router/video-router.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"one-api/controller"
|
||||
"one-api/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SetVideoRouter(router *gin.Engine) {
|
||||
videoV1Router := router.Group("/v1")
|
||||
videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
videoV1Router.POST("/video/generations", controller.RelayTask)
|
||||
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
|
||||
}
|
||||
}
|
||||
@@ -59,6 +59,8 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
|
||||
return true
|
||||
case "billing_not_active":
|
||||
return true
|
||||
case "pre_consume_token_quota_failed":
|
||||
return true
|
||||
}
|
||||
switch err.Error.Type {
|
||||
case "insufficient_quota":
|
||||
|
||||
@@ -276,12 +276,15 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
}
|
||||
if info.Done {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
if info.ClaudeConvertInfo.Usage != nil {
|
||||
oaiUsage := info.ClaudeConvertInfo.Usage
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: info.ClaudeConvertInfo.Usage.PromptTokens,
|
||||
OutputTokens: info.ClaudeConvertInfo.Usage.CompletionTokens,
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
|
||||
@@ -29,9 +29,11 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int)
|
||||
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
|
||||
text := err.Error()
|
||||
lowerText := strings.ToLower(text)
|
||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||
text = "请求上游地址失败"
|
||||
if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") {
|
||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||
text = "请求上游地址失败"
|
||||
}
|
||||
}
|
||||
openAIError := dto.OpenAIError{
|
||||
Message: text,
|
||||
@@ -53,9 +55,11 @@ func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAI
|
||||
func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
|
||||
text := err.Error()
|
||||
lowerText := strings.ToLower(text)
|
||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||
text = "请求上游地址失败"
|
||||
if !strings.HasPrefix(lowerText, "get file base64 from url") {
|
||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||
text = "请求上游地址失败"
|
||||
}
|
||||
}
|
||||
claudeError := dto.ClaudeError{
|
||||
Message: text,
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
|
||||
@@ -30,9 +32,104 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
|
||||
// Convert to base64
|
||||
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
|
||||
|
||||
mimeType := resp.Header.Get("Content-Type")
|
||||
if len(strings.Split(mimeType, ";")) > 1 {
|
||||
// If Content-Type has parameters, take the first part
|
||||
mimeType = strings.Split(mimeType, ";")[0]
|
||||
}
|
||||
if mimeType == "application/octet-stream" {
|
||||
if common.DebugEnabled {
|
||||
println("MIME type is application/octet-stream, trying to guess from URL or filename")
|
||||
}
|
||||
// try to guess the MIME type from the url last segment
|
||||
urlParts := strings.Split(url, "/")
|
||||
if len(urlParts) > 0 {
|
||||
lastSegment := urlParts[len(urlParts)-1]
|
||||
if strings.Contains(lastSegment, ".") {
|
||||
// Extract the file extension
|
||||
filename := strings.Split(lastSegment, ".")
|
||||
if len(filename) > 1 {
|
||||
ext := strings.ToLower(filename[len(filename)-1])
|
||||
// Guess MIME type based on file extension
|
||||
mimeType = GetMimeTypeByExtension(ext)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// try to guess the MIME type from the file extension
|
||||
fileName := resp.Header.Get("Content-Disposition")
|
||||
if fileName != "" {
|
||||
// Extract the filename from the Content-Disposition header
|
||||
parts := strings.Split(fileName, ";")
|
||||
for _, part := range parts {
|
||||
if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
|
||||
fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
|
||||
// Remove quotes if present
|
||||
if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
|
||||
fileName = fileName[1 : len(fileName)-1]
|
||||
}
|
||||
// Guess MIME type based on file extension
|
||||
if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" {
|
||||
mimeType = GetMimeTypeByExtension(ext)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &dto.LocalFileData{
|
||||
Base64Data: base64Data,
|
||||
MimeType: resp.Header.Get("Content-Type"),
|
||||
MimeType: mimeType,
|
||||
Size: int64(len(fileBytes)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetMimeTypeByExtension maps a file extension to a MIME type string.
//
// ext is matched case-insensitively. Surrounding whitespace and a single
// leading dot are ignored, so "png", "PNG" and ".png" (the form returned by
// filepath.Ext) all resolve to "image/png". Unknown or empty extensions
// return "application/octet-stream".
func GetMimeTypeByExtension(ext string) string {
	// Normalize so callers may pass either a bare or a dotted extension.
	ext = strings.ToLower(strings.TrimPrefix(strings.TrimSpace(ext), "."))

	switch ext {
	// Text files: every textual format is reported as plain text.
	case "txt", "md", "markdown", "csv", "json", "xml", "html", "htm":
		return "text/plain"

	// Image files
	case "jpg", "jpeg":
		return "image/jpeg"
	case "png":
		return "image/png"
	case "gif":
		return "image/gif"

	// Audio files
	case "mp3":
		return "audio/mp3"
	case "wav":
		return "audio/wav"
	case "mpeg":
		return "audio/mpeg"

	// Video files
	case "mp4":
		return "video/mp4"
	case "wmv":
		return "video/wmv"
	case "flv":
		return "video/flv"
	case "mov":
		return "video/mov"
	case "mpg":
		return "video/mpg"
	case "avi":
		return "video/avi"
	case "mpegps":
		return "video/mpegps"

	// Document files
	case "pdf":
		return "application/pdf"

	default:
		return "application/octet-stream" // Default for unknown types
	}
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -63,3 +64,13 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
info["cache_creation_ratio"] = cacheCreationRatio
|
||||
return info
|
||||
}
|
||||
|
||||
func GenerateMjOtherInfo(priceData helper.PerCallPriceData) map[string]interface{} {
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = priceData.ModelPrice
|
||||
other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio
|
||||
if priceData.GroupRatioInfo.HasSpecialRatio {
|
||||
other["user_group_ratio"] = priceData.GroupRatioInfo.GroupSpecialRatio
|
||||
}
|
||||
return other
|
||||
}
|
||||
|
||||
115
service/quota.go
115
service/quota.go
@@ -3,6 +3,8 @@ package service
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"one-api/common"
|
||||
constant2 "one-api/constant"
|
||||
"one-api/dto"
|
||||
@@ -10,7 +12,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -45,9 +47,9 @@ func calculateAudioQuota(info QuotaInfo) int {
|
||||
return int(quota.IntPart())
|
||||
}
|
||||
|
||||
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(info.ModelName))
|
||||
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(info.ModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(info.ModelName))
|
||||
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(info.ModelName))
|
||||
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(info.ModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(info.ModelName))
|
||||
|
||||
groupRatio := decimal.NewFromFloat(info.GroupRatio)
|
||||
modelRatio := decimal.NewFromFloat(info.ModelRatio)
|
||||
@@ -93,12 +95,21 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
textOutTokens := usage.OutputTokenDetails.TextTokens
|
||||
audioInputTokens := usage.InputTokenDetails.AudioTokens
|
||||
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
|
||||
if ok {
|
||||
groupRatio = userGroupRatio
|
||||
groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
|
||||
modelRatio, _ := ratio_setting.GetModelRatio(modelName)
|
||||
|
||||
autoGroup, exists := ctx.Get("auto_group")
|
||||
if exists {
|
||||
groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string))
|
||||
log.Printf("final group ratio: %f", groupRatio)
|
||||
relayInfo.UsingGroup = autoGroup.(string)
|
||||
}
|
||||
|
||||
actualGroupRatio := groupRatio
|
||||
userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
|
||||
if ok {
|
||||
actualGroupRatio = userGroupRatio
|
||||
}
|
||||
modelRatio, _ := operation_setting.GetModelRatio(modelName)
|
||||
|
||||
quotaInfo := QuotaInfo{
|
||||
InputDetails: TokenDetails{
|
||||
@@ -112,7 +123,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
ModelName: modelName,
|
||||
UsePrice: relayInfo.UsePrice,
|
||||
ModelRatio: modelRatio,
|
||||
GroupRatio: groupRatio,
|
||||
GroupRatio: actualGroupRatio,
|
||||
}
|
||||
|
||||
quota := calculateAudioQuota(quotaInfo)
|
||||
@@ -134,8 +145,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
}
|
||||
|
||||
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
|
||||
usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
|
||||
modelPrice float64, usePrice bool, extraContent string) {
|
||||
usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
textInputTokens := usage.InputTokenDetails.TextTokens
|
||||
@@ -145,15 +155,15 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(modelName))
|
||||
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
|
||||
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(modelName))
|
||||
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
|
||||
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
usePrice := priceData.UsePrice
|
||||
|
||||
actualGroupRatio := groupRatio
|
||||
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
|
||||
if ok {
|
||||
actualGroupRatio = userGroupRatio
|
||||
}
|
||||
quotaInfo := QuotaInfo{
|
||||
InputDetails: TokenDetails{
|
||||
TextTokens: textInputTokens,
|
||||
@@ -166,7 +176,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
ModelName: modelName,
|
||||
UsePrice: usePrice,
|
||||
ModelRatio: modelRatio,
|
||||
GroupRatio: actualGroupRatio,
|
||||
GroupRatio: groupRatio,
|
||||
}
|
||||
|
||||
quota := calculateAudioQuota(quotaInfo)
|
||||
@@ -198,9 +208,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio)
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
|
||||
}
|
||||
|
||||
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
@@ -214,15 +224,25 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := priceData.CompletionRatio
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
userGroupRatio := priceData.UserGroupRatio
|
||||
cacheRatio := priceData.CacheRatio
|
||||
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
||||
|
||||
cacheCreationRatio := priceData.CacheCreationRatio
|
||||
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
|
||||
if relayInfo.ChannelType == common.ChannelTypeOpenRouter {
|
||||
promptTokens -= cacheTokens
|
||||
if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 {
|
||||
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData)
|
||||
if promptTokens >= maybeCacheCreationTokens {
|
||||
cacheCreationTokens = maybeCacheCreationTokens
|
||||
}
|
||||
}
|
||||
promptTokens -= cacheCreationTokens
|
||||
}
|
||||
|
||||
calculateQuota := 0.0
|
||||
if !priceData.UsePrice {
|
||||
calculateQuota = float64(promptTokens)
|
||||
@@ -265,9 +285,30 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
}
|
||||
|
||||
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
|
||||
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, userGroupRatio)
|
||||
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
|
||||
}
|
||||
|
||||
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int {
|
||||
if priceData.CacheCreationRatio == 1 {
|
||||
return 0
|
||||
}
|
||||
quotaPrice := priceData.ModelRatio / common.QuotaPerUnit
|
||||
promptCacheCreatePrice := quotaPrice * priceData.CacheCreationRatio
|
||||
promptCacheReadPrice := quotaPrice * priceData.CacheRatio
|
||||
completionPrice := quotaPrice * priceData.CompletionRatio
|
||||
|
||||
cost := usage.Cost
|
||||
totalPromptTokens := float64(usage.PromptTokens)
|
||||
completionTokens := float64(usage.CompletionTokens)
|
||||
promptCacheReadTokens := float64(usage.PromptTokensDetails.CachedTokens)
|
||||
|
||||
return int(math.Round((cost -
|
||||
totalPromptTokens*quotaPrice +
|
||||
promptCacheReadTokens*(quotaPrice-promptCacheReadPrice) -
|
||||
completionTokens*completionPrice) /
|
||||
(promptCacheCreatePrice - quotaPrice)))
|
||||
}
|
||||
|
||||
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
@@ -281,21 +322,15 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(relayInfo.OriginModelName))
|
||||
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
|
||||
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(relayInfo.OriginModelName))
|
||||
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
|
||||
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
usePrice := priceData.UsePrice
|
||||
|
||||
actualGroupRatio := groupRatio
|
||||
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
|
||||
if ok {
|
||||
actualGroupRatio = userGroupRatio
|
||||
}
|
||||
|
||||
quotaInfo := QuotaInfo{
|
||||
InputDetails: TokenDetails{
|
||||
TextTokens: textInputTokens,
|
||||
@@ -308,7 +343,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
ModelName: relayInfo.OriginModelName,
|
||||
UsePrice: usePrice,
|
||||
ModelRatio: modelRatio,
|
||||
GroupRatio: actualGroupRatio,
|
||||
GroupRatio: groupRatio,
|
||||
}
|
||||
|
||||
quota := calculateAudioQuota(quotaInfo)
|
||||
@@ -348,9 +383,9 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio)
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
|
||||
}
|
||||
|
||||
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
|
||||
|
||||
@@ -171,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
|
||||
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||
}
|
||||
}
|
||||
toolTokens, err := CountTokenInput(countStr, request.Model)
|
||||
toolTokens := CountTokenInput(countStr, request.Model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -194,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
|
||||
|
||||
// Count tokens in system message
|
||||
if request.System != "" {
|
||||
systemTokens, err := CountTokenInput(request.System, model)
|
||||
systemTokens := CountTokenInput(request.System, model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -296,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
switch request.Type {
|
||||
case dto.RealtimeEventTypeSessionUpdate:
|
||||
if request.Session != nil {
|
||||
msgTokens, err := CountTextToken(request.Session.Instructions, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
msgTokens := CountTextToken(request.Session.Instructions, model)
|
||||
textToken += msgTokens
|
||||
}
|
||||
case dto.RealtimeEventResponseAudioDelta:
|
||||
@@ -311,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
audioToken += atk
|
||||
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
|
||||
// count text token
|
||||
tkm, err := CountTextToken(request.Delta, model)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error counting text token: %v", err)
|
||||
}
|
||||
tkm := CountTextToken(request.Delta, model)
|
||||
textToken += tkm
|
||||
case dto.RealtimeEventInputAudioBufferAppend:
|
||||
// count audio token
|
||||
@@ -329,10 +323,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
case "message":
|
||||
for _, content := range request.Item.Content {
|
||||
if content.Type == "input_text" {
|
||||
tokens, err := CountTextToken(content.Text, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
tokens := CountTextToken(content.Text, model)
|
||||
textToken += tokens
|
||||
}
|
||||
}
|
||||
@@ -343,10 +334,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
if !info.IsFirstRequest {
|
||||
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
|
||||
for _, tool := range info.RealtimeTools {
|
||||
toolTokens, err := CountTokenInput(tool, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
toolTokens := CountTokenInput(tool, model)
|
||||
textToken += 8
|
||||
textToken += toolTokens
|
||||
}
|
||||
@@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
|
||||
return tokenNum, nil
|
||||
}
|
||||
|
||||
func CountTokenInput(input any, model string) (int, error) {
|
||||
func CountTokenInput(input any, model string) int {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return CountTextToken(v, model)
|
||||
@@ -432,13 +420,13 @@ func CountTokenInput(input any, model string) (int, error) {
|
||||
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
||||
tokens := 0
|
||||
for _, message := range messages {
|
||||
tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
|
||||
tkm := CountTokenInput(message.Delta.GetContentString(), model)
|
||||
tokens += tkm
|
||||
if message.Delta.ToolCalls != nil {
|
||||
for _, tool := range message.Delta.ToolCalls {
|
||||
tkm, _ := CountTokenInput(tool.Function.Name, model)
|
||||
tkm := CountTokenInput(tool.Function.Name, model)
|
||||
tokens += tkm
|
||||
tkm, _ = CountTokenInput(tool.Function.Arguments, model)
|
||||
tkm = CountTokenInput(tool.Function.Arguments, model)
|
||||
tokens += tkm
|
||||
}
|
||||
}
|
||||
@@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
|
||||
return tokens
|
||||
}
|
||||
|
||||
func CountTTSToken(text string, model string) (int, error) {
|
||||
func CountTTSToken(text string, model string) int {
|
||||
if strings.HasPrefix(model, "tts") {
|
||||
return utf8.RuneCountInString(text), nil
|
||||
return utf8.RuneCountInString(text)
|
||||
} else {
|
||||
return CountTextToken(text, model)
|
||||
}
|
||||
@@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
|
||||
//}
|
||||
|
||||
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||
func CountTextToken(text string, model string) (int, error) {
|
||||
var err error
|
||||
func CountTextToken(text string, model string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text), err
|
||||
return getTokenNum(tokenEncoder, text)
|
||||
}
|
||||
|
||||
@@ -16,13 +16,13 @@ import (
|
||||
// return 0, errors.New("unknown relay mode")
|
||||
//}
|
||||
|
||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
|
||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
ctkm, err := CountTextToken(responseText, modeName)
|
||||
ctkm := CountTextToken(responseText, modeName)
|
||||
usage.CompletionTokens = ctkm
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return usage, err
|
||||
return usage
|
||||
}
|
||||
|
||||
func ValidUsage(usage *dto.Usage) bool {
|
||||
|
||||
31
setting/auto_group.go
Normal file
31
setting/auto_group.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package setting
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
var AutoGroups = []string{
|
||||
"default",
|
||||
}
|
||||
|
||||
var DefaultUseAutoGroup = false
|
||||
|
||||
func ContainsAutoGroup(group string) bool {
|
||||
for _, autoGroup := range AutoGroups {
|
||||
if autoGroup == group {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func UpdateAutoGroupsByJsonString(jsonString string) error {
|
||||
AutoGroups = make([]string, 0)
|
||||
return json.Unmarshal([]byte(jsonString), &AutoGroups)
|
||||
}
|
||||
|
||||
func AutoGroups2JsonString() string {
|
||||
jsonBytes, err := json.Marshal(AutoGroups)
|
||||
if err != nil {
|
||||
return "[]"
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
@@ -17,6 +17,8 @@ const (
|
||||
const (
|
||||
// Gemini Audio Input Price
|
||||
Gemini25FlashPreviewInputAudioPrice = 1.00
|
||||
Gemini25FlashProductionInputAudioPrice = 1.00 // for `gemini-2.5-flash`
|
||||
Gemini25FlashLitePreviewInputAudioPrice = 0.50
|
||||
Gemini25FlashNativeAudioInputAudioPrice = 3.00
|
||||
Gemini20FlashInputAudioPrice = 0.70
|
||||
)
|
||||
@@ -64,10 +66,14 @@ func GetFileSearchPricePerThousand() float64 {
|
||||
}
|
||||
|
||||
func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 {
|
||||
if strings.HasPrefix(modelName, "gemini-2.5-flash-preview") {
|
||||
return Gemini25FlashPreviewInputAudioPrice
|
||||
} else if strings.HasPrefix(modelName, "gemini-2.5-flash-preview-native-audio") {
|
||||
if strings.HasPrefix(modelName, "gemini-2.5-flash-preview-native-audio") {
|
||||
return Gemini25FlashNativeAudioInputAudioPrice
|
||||
} else if strings.HasPrefix(modelName, "gemini-2.5-flash-preview-lite") {
|
||||
return Gemini25FlashLitePreviewInputAudioPrice
|
||||
} else if strings.HasPrefix(modelName, "gemini-2.5-flash-preview") {
|
||||
return Gemini25FlashPreviewInputAudioPrice
|
||||
} else if strings.HasPrefix(modelName, "gemini-2.5-flash") {
|
||||
return Gemini25FlashProductionInputAudioPrice
|
||||
} else if strings.HasPrefix(modelName, "gemini-2.0-flash") {
|
||||
return Gemini20FlashInputAudioPrice
|
||||
}
|
||||
|
||||
@@ -1,8 +1,45 @@
|
||||
package setting
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
var PayAddress = ""
|
||||
var CustomCallbackAddress = ""
|
||||
var EpayId = ""
|
||||
var EpayKey = ""
|
||||
var Price = 7.3
|
||||
var MinTopUp = 1
|
||||
|
||||
var PayMethods = []map[string]string{
|
||||
{
|
||||
"name": "支付宝",
|
||||
"color": "rgba(var(--semi-blue-5), 1)",
|
||||
"type": "alipay",
|
||||
},
|
||||
{
|
||||
"name": "微信",
|
||||
"color": "rgba(var(--semi-green-5), 1)",
|
||||
"type": "wxpay",
|
||||
},
|
||||
}
|
||||
|
||||
func UpdatePayMethodsByJsonString(jsonString string) error {
|
||||
PayMethods = make([]map[string]string, 0)
|
||||
return json.Unmarshal([]byte(jsonString), &PayMethods)
|
||||
}
|
||||
|
||||
func PayMethods2JsonString() string {
|
||||
jsonBytes, err := json.Marshal(PayMethods)
|
||||
if err != nil {
|
||||
return "[]"
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
func ContainsPayMethod(method string) bool {
|
||||
for _, payMethod := range PayMethods {
|
||||
if payMethod["type"] == method {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package operation_setting
|
||||
package ratio_setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -85,7 +85,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error {
|
||||
cacheRatioMapMutex.Lock()
|
||||
defer cacheRatioMapMutex.Unlock()
|
||||
cacheRatioMap = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
|
||||
err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
|
||||
if err == nil {
|
||||
InvalidateExposedDataCache()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// GetCacheRatio returns the cache ratio for a model
|
||||
@@ -106,3 +110,13 @@ func GetCreateCacheRatio(name string) (float64, bool) {
|
||||
}
|
||||
return ratio, true
|
||||
}
|
||||
|
||||
func GetCacheRatioCopy() map[string]float64 {
|
||||
cacheRatioMapMutex.RLock()
|
||||
defer cacheRatioMapMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(cacheRatioMap))
|
||||
for k, v := range cacheRatioMap {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
}
|
||||
17
setting/ratio_setting/expose_ratio.go
Normal file
17
setting/ratio_setting/expose_ratio.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package ratio_setting
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
var exposeRatioEnabled atomic.Bool
|
||||
|
||||
func init() {
|
||||
exposeRatioEnabled.Store(false)
|
||||
}
|
||||
|
||||
func SetExposeRatioEnabled(enabled bool) {
|
||||
exposeRatioEnabled.Store(enabled)
|
||||
}
|
||||
|
||||
func IsExposeRatioEnabled() bool {
|
||||
return exposeRatioEnabled.Load()
|
||||
}
|
||||
55
setting/ratio_setting/exposed_cache.go
Normal file
55
setting/ratio_setting/exposed_cache.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package ratio_setting
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const exposedDataTTL = 30 * time.Second
|
||||
|
||||
type exposedCache struct {
|
||||
data gin.H
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
exposedData atomic.Value
|
||||
rebuildMu sync.Mutex
|
||||
)
|
||||
|
||||
func InvalidateExposedDataCache() {
|
||||
exposedData.Store((*exposedCache)(nil))
|
||||
}
|
||||
|
||||
func cloneGinH(src gin.H) gin.H {
|
||||
dst := make(gin.H, len(src))
|
||||
for k, v := range src {
|
||||
dst[k] = v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func GetExposedData() gin.H {
|
||||
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
|
||||
return cloneGinH(c.data)
|
||||
}
|
||||
rebuildMu.Lock()
|
||||
defer rebuildMu.Unlock()
|
||||
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
|
||||
return cloneGinH(c.data)
|
||||
}
|
||||
newData := gin.H{
|
||||
"model_ratio": GetModelRatioCopy(),
|
||||
"completion_ratio": GetCompletionRatioCopy(),
|
||||
"cache_ratio": GetCacheRatioCopy(),
|
||||
"model_price": GetModelPriceCopy(),
|
||||
}
|
||||
exposedData.Store(&exposedCache{
|
||||
data: newData,
|
||||
expiresAt: time.Now().Add(exposedDataTTL),
|
||||
})
|
||||
return cloneGinH(newData)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package setting
|
||||
package ratio_setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -73,15 +73,15 @@ func GetGroupRatio(name string) float64 {
|
||||
return ratio
|
||||
}
|
||||
|
||||
func GetGroupGroupRatio(group, name string) (float64, bool) {
|
||||
func GetGroupGroupRatio(userGroup, usingGroup string) (float64, bool) {
|
||||
groupGroupRatioMutex.RLock()
|
||||
defer groupGroupRatioMutex.RUnlock()
|
||||
|
||||
gp, ok := GroupGroupRatio[group]
|
||||
gp, ok := GroupGroupRatio[userGroup]
|
||||
if !ok {
|
||||
return -1, false
|
||||
}
|
||||
ratio, ok := gp[name]
|
||||
ratio, ok := gp[usingGroup]
|
||||
if !ok {
|
||||
return -1, false
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
package operation_setting
|
||||
package ratio_setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"one-api/common"
|
||||
"one-api/setting/operation_setting"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
@@ -139,6 +140,7 @@ var defaultModelRatio = map[string]float64{
|
||||
"gemini-2.0-flash": 0.05,
|
||||
"gemini-2.5-pro-exp-03-25": 0.625,
|
||||
"gemini-2.5-pro-preview-03-25": 0.625,
|
||||
"gemini-2.5-pro": 0.625,
|
||||
"gemini-2.5-flash-preview-04-17": 0.075,
|
||||
"gemini-2.5-flash-preview-04-17-thinking": 0.075,
|
||||
"gemini-2.5-flash-preview-04-17-nothinking": 0.075,
|
||||
@@ -147,6 +149,8 @@ var defaultModelRatio = map[string]float64{
|
||||
"gemini-2.5-flash-preview-05-20-nothinking": 0.075,
|
||||
"gemini-2.5-flash-thinking-*": 0.075, // 用于为后续所有2.5 flash thinking budget 模型设置默认倍率
|
||||
"gemini-2.5-pro-thinking-*": 0.625, // 用于为后续所有2.5 pro thinking budget 模型设置默认倍率
|
||||
"gemini-2.5-flash-lite-preview-06-17": 0.05,
|
||||
"gemini-2.5-flash": 0.15,
|
||||
"text-embedding-004": 0.001,
|
||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
@@ -316,7 +320,11 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
|
||||
modelPriceMapMutex.Lock()
|
||||
defer modelPriceMapMutex.Unlock()
|
||||
modelPriceMap = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &modelPriceMap)
|
||||
err := json.Unmarshal([]byte(jsonStr), &modelPriceMap)
|
||||
if err == nil {
|
||||
InvalidateExposedDataCache()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
|
||||
@@ -344,7 +352,11 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
|
||||
modelRatioMapMutex.Lock()
|
||||
defer modelRatioMapMutex.Unlock()
|
||||
modelRatioMap = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
|
||||
err := json.Unmarshal([]byte(jsonStr), &modelRatioMap)
|
||||
if err == nil {
|
||||
InvalidateExposedDataCache()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 处理带有思考预算的模型名称,方便统一定价
|
||||
@@ -366,7 +378,7 @@ func GetModelRatio(name string) (float64, bool) {
|
||||
}
|
||||
ratio, ok := modelRatioMap[name]
|
||||
if !ok {
|
||||
return 37.5, SelfUseModeEnabled
|
||||
return 37.5, operation_setting.SelfUseModeEnabled
|
||||
}
|
||||
return ratio, true
|
||||
}
|
||||
@@ -404,13 +416,22 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
|
||||
CompletionRatioMutex.Lock()
|
||||
defer CompletionRatioMutex.Unlock()
|
||||
CompletionRatio = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
|
||||
err := json.Unmarshal([]byte(jsonStr), &CompletionRatio)
|
||||
if err == nil {
|
||||
InvalidateExposedDataCache()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func GetCompletionRatio(name string) float64 {
|
||||
CompletionRatioMutex.RLock()
|
||||
defer CompletionRatioMutex.RUnlock()
|
||||
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4o-gizmo") {
|
||||
name = "gpt-4o-gizmo-*"
|
||||
}
|
||||
if strings.Contains(name, "/") {
|
||||
if ratio, ok := CompletionRatio[name]; ok {
|
||||
return ratio
|
||||
@@ -428,12 +449,6 @@ func GetCompletionRatio(name string) float64 {
|
||||
|
||||
func getHardcodedCompletionModelRatio(name string) (float64, bool) {
|
||||
lowercaseName := strings.ToLower(name)
|
||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||
name = "gpt-4-gizmo-*"
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4o-gizmo") {
|
||||
name = "gpt-4o-gizmo-*"
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") {
|
||||
if strings.HasPrefix(name, "gpt-4o") {
|
||||
if name == "gpt-4o-2024-05-13" {
|
||||
@@ -487,12 +502,17 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
|
||||
return 4, true
|
||||
} else if strings.HasPrefix(name, "gemini-2.5-pro") { // 移除preview来增加兼容性,这里假设正式版的倍率和preview一致
|
||||
return 8, true
|
||||
} else if strings.HasPrefix(name, "gemini-2.5-flash") { // 同上
|
||||
if strings.HasSuffix(name, "-nothinking") {
|
||||
return 4, false
|
||||
} else {
|
||||
return 3.5 / 0.6, false
|
||||
} else if strings.HasPrefix(name, "gemini-2.5-flash") { // 处理不同的flash模型倍率
|
||||
if strings.HasPrefix(name, "gemini-2.5-flash-preview") {
|
||||
if strings.HasSuffix(name, "-nothinking") {
|
||||
return 4, true
|
||||
}
|
||||
return 3.5 / 0.15, true
|
||||
}
|
||||
if strings.HasPrefix(name, "gemini-2.5-flash-lite-preview") {
|
||||
return 4, true
|
||||
}
|
||||
return 2.5 / 0.3, true
|
||||
}
|
||||
return 4, false
|
||||
}
|
||||
@@ -608,3 +628,33 @@ func GetImageRatio(name string) (float64, bool) {
|
||||
}
|
||||
return ratio, true
|
||||
}
|
||||
|
||||
func GetModelRatioCopy() map[string]float64 {
|
||||
modelRatioMapMutex.RLock()
|
||||
defer modelRatioMapMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(modelRatioMap))
|
||||
for k, v := range modelRatioMap {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
}
|
||||
|
||||
func GetModelPriceCopy() map[string]float64 {
|
||||
modelPriceMapMutex.RLock()
|
||||
defer modelPriceMapMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(modelPriceMap))
|
||||
for k, v := range modelPriceMap {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
}
|
||||
|
||||
func GetCompletionRatioCopy() map[string]float64 {
|
||||
CompletionRatioMutex.RLock()
|
||||
defer CompletionRatioMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(CompletionRatio))
|
||||
for k, v := range CompletionRatio {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
}
|
||||
@@ -50,3 +50,10 @@ func GroupInUserUsableGroups(groupName string) bool {
|
||||
_, ok := userUsableGroups[groupName]
|
||||
return ok
|
||||
}
|
||||
|
||||
func GetUsableGroupDescription(groupName string) string {
|
||||
if desc, ok := userUsableGroups[groupName]; ok {
|
||||
return desc
|
||||
}
|
||||
return groupName
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ import {
|
||||
Tag,
|
||||
Typography,
|
||||
Skeleton,
|
||||
Badge,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { StatusContext } from '../../context/Status/index.js';
|
||||
import { useStyle, styleActions } from '../../context/Style/index.js';
|
||||
@@ -43,6 +44,7 @@ const HeaderBar = () => {
|
||||
const [mobileMenuOpen, setMobileMenuOpen] = useState(false);
|
||||
const location = useLocation();
|
||||
const [noticeVisible, setNoticeVisible] = useState(false);
|
||||
const [unreadCount, setUnreadCount] = useState(0);
|
||||
|
||||
const systemName = getSystemName();
|
||||
const logo = getLogo();
|
||||
@@ -53,9 +55,44 @@ const HeaderBar = () => {
|
||||
const docsLink = statusState?.status?.docs_link || '';
|
||||
const isDemoSiteMode = statusState?.status?.demo_site_enabled || false;
|
||||
|
||||
const isConsoleRoute = location.pathname.startsWith('/console');
|
||||
|
||||
const theme = useTheme();
|
||||
const setTheme = useSetTheme();
|
||||
|
||||
const announcements = statusState?.status?.announcements || [];
|
||||
|
||||
const getAnnouncementKey = (a) => `${a?.publishDate || ''}-${(a?.content || '').slice(0, 30)}`;
|
||||
|
||||
const calculateUnreadCount = () => {
|
||||
if (!announcements.length) return 0;
|
||||
let readKeys = [];
|
||||
try {
|
||||
readKeys = JSON.parse(localStorage.getItem('notice_read_keys')) || [];
|
||||
} catch (_) {
|
||||
readKeys = [];
|
||||
}
|
||||
const readSet = new Set(readKeys);
|
||||
return announcements.filter((a) => !readSet.has(getAnnouncementKey(a))).length;
|
||||
};
|
||||
|
||||
const getUnreadKeys = () => {
|
||||
if (!announcements.length) return [];
|
||||
let readKeys = [];
|
||||
try {
|
||||
readKeys = JSON.parse(localStorage.getItem('notice_read_keys')) || [];
|
||||
} catch (_) {
|
||||
readKeys = [];
|
||||
}
|
||||
const readSet = new Set(readKeys);
|
||||
return announcements.filter((a) => !readSet.has(getAnnouncementKey(a))).map(getAnnouncementKey);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
setUnreadCount(calculateUnreadCount());
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [announcements]);
|
||||
|
||||
const mainNavLinks = [
|
||||
{
|
||||
text: t('首页'),
|
||||
@@ -106,6 +143,25 @@ const HeaderBar = () => {
|
||||
}, 3000);
|
||||
};
|
||||
|
||||
const handleNoticeOpen = () => {
|
||||
setNoticeVisible(true);
|
||||
};
|
||||
|
||||
const handleNoticeClose = () => {
|
||||
setNoticeVisible(false);
|
||||
if (announcements.length) {
|
||||
let readKeys = [];
|
||||
try {
|
||||
readKeys = JSON.parse(localStorage.getItem('notice_read_keys')) || [];
|
||||
} catch (_) {
|
||||
readKeys = [];
|
||||
}
|
||||
const mergedKeys = Array.from(new Set([...readKeys, ...announcements.map(getAnnouncementKey)]));
|
||||
localStorage.setItem('notice_read_keys', JSON.stringify(mergedKeys));
|
||||
}
|
||||
setUnreadCount(0);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (theme === 'dark') {
|
||||
document.body.setAttribute('theme-mode', 'dark');
|
||||
@@ -353,15 +409,14 @@ const HeaderBar = () => {
|
||||
}
|
||||
};
|
||||
|
||||
// 检查当前路由是否以/console开头
|
||||
const isConsoleRoute = location.pathname.startsWith('/console');
|
||||
|
||||
return (
|
||||
<header className="text-semi-color-text-0 sticky top-0 z-50 transition-colors duration-300 bg-white/75 dark:bg-zinc-900/75 backdrop-blur-lg">
|
||||
<NoticeModal
|
||||
visible={noticeVisible}
|
||||
onClose={() => setNoticeVisible(false)}
|
||||
onClose={handleNoticeClose}
|
||||
isMobile={styleState.isMobile}
|
||||
defaultTab={unreadCount > 0 ? 'system' : 'inApp'}
|
||||
unreadKeys={getUnreadKeys()}
|
||||
/>
|
||||
<div className="w-full px-2">
|
||||
<div className="flex items-center justify-between h-16">
|
||||
@@ -462,14 +517,27 @@ const HeaderBar = () => {
|
||||
</Dropdown>
|
||||
)}
|
||||
|
||||
<Button
|
||||
icon={<IconBell className="text-lg" />}
|
||||
aria-label={t('系统公告')}
|
||||
onClick={() => setNoticeVisible(true)}
|
||||
theme="borderless"
|
||||
type="tertiary"
|
||||
className="!p-1.5 !text-current focus:!bg-semi-color-fill-1 dark:focus:!bg-gray-700 !rounded-full !bg-semi-color-fill-0 dark:!bg-semi-color-fill-1 hover:!bg-semi-color-fill-1 dark:hover:!bg-semi-color-fill-2"
|
||||
/>
|
||||
{unreadCount > 0 ? (
|
||||
<Badge count={unreadCount} type="danger" overflowCount={99}>
|
||||
<Button
|
||||
icon={<IconBell className="text-lg" />}
|
||||
aria-label={t('系统公告')}
|
||||
onClick={handleNoticeOpen}
|
||||
theme="borderless"
|
||||
type="tertiary"
|
||||
className="!p-1.5 !text-current focus:!bg-semi-color-fill-1 dark:focus:!bg-gray-700 !rounded-full !bg-semi-color-fill-0 dark:!bg-semi-color-fill-1 hover:!bg-semi-color-fill-1 dark:hover:!bg-semi-color-fill-2"
|
||||
/>
|
||||
</Badge>
|
||||
) : (
|
||||
<Button
|
||||
icon={<IconBell className="text-lg" />}
|
||||
aria-label={t('系统公告')}
|
||||
onClick={handleNoticeOpen}
|
||||
theme="borderless"
|
||||
type="tertiary"
|
||||
className="!p-1.5 !text-current focus:!bg-semi-color-fill-1 dark:focus:!bg-gray-700 !rounded-full !bg-semi-color-fill-0 dark:!bg-semi-color-fill-1 hover:!bg-semi-color-fill-1 dark:hover:!bg-semi-color-fill-2"
|
||||
/>
|
||||
)}
|
||||
|
||||
<Button
|
||||
icon={theme === 'dark' ? <IconSun size="large" className="text-yellow-500" /> : <IconMoon size="large" className="text-gray-300" />}
|
||||
|
||||
@@ -1,14 +1,36 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Button, Modal, Empty } from '@douyinfe/semi-ui';
|
||||
import React, { useEffect, useState, useContext, useMemo } from 'react';
|
||||
import { Button, Modal, Empty, Tabs, TabPane, Timeline } from '@douyinfe/semi-ui';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { API, showError } from '../../helpers';
|
||||
import { API, showError, getRelativeTime } from '../../helpers';
|
||||
import { marked } from 'marked';
|
||||
import { IllustrationNoContent, IllustrationNoContentDark } from '@douyinfe/semi-illustrations';
|
||||
import { StatusContext } from '../../context/Status/index.js';
|
||||
import { Bell, Megaphone } from 'lucide-react';
|
||||
|
||||
const NoticeModal = ({ visible, onClose, isMobile }) => {
|
||||
const NoticeModal = ({ visible, onClose, isMobile, defaultTab = 'inApp', unreadKeys = [] }) => {
|
||||
const { t } = useTranslation();
|
||||
const [noticeContent, setNoticeContent] = useState('');
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [activeTab, setActiveTab] = useState(defaultTab);
|
||||
|
||||
const [statusState] = useContext(StatusContext);
|
||||
|
||||
const announcements = statusState?.status?.announcements || [];
|
||||
|
||||
const unreadSet = useMemo(() => new Set(unreadKeys), [unreadKeys]);
|
||||
|
||||
const getKeyForItem = (item) => `${item?.publishDate || ''}-${(item?.content || '').slice(0, 30)}`;
|
||||
|
||||
const processedAnnouncements = useMemo(() => {
|
||||
return (announcements || []).slice(0, 20).map(item => ({
|
||||
key: getKeyForItem(item),
|
||||
type: item.type || 'default',
|
||||
time: getRelativeTime(item.publishDate),
|
||||
content: item.content,
|
||||
extra: item.extra,
|
||||
isUnread: unreadSet.has(getKeyForItem(item))
|
||||
}));
|
||||
}, [announcements, unreadSet]);
|
||||
|
||||
const handleCloseTodayNotice = () => {
|
||||
const today = new Date().toDateString();
|
||||
@@ -44,7 +66,13 @@ const NoticeModal = ({ visible, onClose, isMobile }) => {
|
||||
}
|
||||
}, [visible]);
|
||||
|
||||
const renderContent = () => {
|
||||
useEffect(() => {
|
||||
if (visible) {
|
||||
setActiveTab(defaultTab);
|
||||
}
|
||||
}, [defaultTab, visible]);
|
||||
|
||||
const renderMarkdownNotice = () => {
|
||||
if (loading) {
|
||||
return <div className="py-12"><Empty description={t('加载中...')} /></div>;
|
||||
}
|
||||
@@ -64,14 +92,80 @@ const NoticeModal = ({ visible, onClose, isMobile }) => {
|
||||
return (
|
||||
<div
|
||||
dangerouslySetInnerHTML={{ __html: noticeContent }}
|
||||
className="notice-content-scroll max-h-[60vh] overflow-y-auto pr-2"
|
||||
className="notice-content-scroll max-h-[55vh] overflow-y-auto pr-2"
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const renderAnnouncementTimeline = () => {
|
||||
if (processedAnnouncements.length === 0) {
|
||||
return (
|
||||
<div className="py-12">
|
||||
<Empty
|
||||
image={<IllustrationNoContent style={{ width: 150, height: 150 }} />}
|
||||
darkModeImage={<IllustrationNoContentDark style={{ width: 150, height: 150 }} />}
|
||||
description={t('暂无系统公告')}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="max-h-[55vh] overflow-y-auto pr-2 card-content-scroll">
|
||||
<Timeline mode="alternate">
|
||||
{processedAnnouncements.map((item, idx) => {
|
||||
const htmlContent = marked.parse(item.content || '');
|
||||
const htmlExtra = item.extra ? marked.parse(item.extra) : '';
|
||||
return (
|
||||
<Timeline.Item
|
||||
key={idx}
|
||||
type={item.type}
|
||||
time={item.time}
|
||||
className={item.isUnread ? '' : ''}
|
||||
>
|
||||
<div>
|
||||
<div
|
||||
className={item.isUnread ? 'shine-text' : ''}
|
||||
dangerouslySetInnerHTML={{ __html: htmlContent }}
|
||||
/>
|
||||
{item.extra && (
|
||||
<div
|
||||
className="text-xs text-gray-500"
|
||||
dangerouslySetInnerHTML={{ __html: htmlExtra }}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</Timeline.Item>
|
||||
);
|
||||
})}
|
||||
</Timeline>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const renderBody = () => {
|
||||
if (activeTab === 'inApp') {
|
||||
return renderMarkdownNotice();
|
||||
}
|
||||
return renderAnnouncementTimeline();
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={t('系统公告')}
|
||||
title={
|
||||
<div className="flex items-center justify-between w-full">
|
||||
<span>{t('系统公告')}</span>
|
||||
<Tabs
|
||||
activeKey={activeTab}
|
||||
onChange={setActiveTab}
|
||||
type='card'
|
||||
size='small'
|
||||
>
|
||||
<TabPane tab={<span className="flex items-center gap-1"><Bell size={14} /> {t('通知')}</span>} itemKey='inApp' />
|
||||
<TabPane tab={<span className="flex items-center gap-1"><Megaphone size={14} /> {t('系统公告')}</span>} itemKey='system' />
|
||||
</Tabs>
|
||||
</div>
|
||||
}
|
||||
visible={visible}
|
||||
onCancel={onClose}
|
||||
footer={(
|
||||
@@ -82,7 +176,7 @@ const NoticeModal = ({ visible, onClose, isMobile }) => {
|
||||
)}
|
||||
size={isMobile ? 'full-width' : 'large'}
|
||||
>
|
||||
{renderContent()}
|
||||
{renderBody()}
|
||||
</Modal>
|
||||
);
|
||||
};
|
||||
|
||||
236
web/src/components/settings/ChannelSelectorModal.js
Normal file
236
web/src/components/settings/ChannelSelectorModal.js
Normal file
@@ -0,0 +1,236 @@
|
||||
import React, { useState, useEffect, forwardRef, useImperativeHandle } from 'react';
|
||||
import { isMobile } from '../../helpers';
|
||||
import {
|
||||
Modal,
|
||||
Table,
|
||||
Input,
|
||||
Space,
|
||||
Highlight,
|
||||
Select,
|
||||
Tag,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import { IconSearch } from '@douyinfe/semi-icons';
|
||||
import { CheckCircle, XCircle, AlertCircle, HelpCircle } from 'lucide-react';
|
||||
|
||||
const ChannelSelectorModal = forwardRef(({
|
||||
visible,
|
||||
onCancel,
|
||||
onOk,
|
||||
allChannels,
|
||||
selectedChannelIds,
|
||||
setSelectedChannelIds,
|
||||
channelEndpoints,
|
||||
updateChannelEndpoint,
|
||||
t,
|
||||
}, ref) => {
|
||||
const [searchText, setSearchText] = useState('');
|
||||
const [currentPage, setCurrentPage] = useState(1);
|
||||
const [pageSize, setPageSize] = useState(10);
|
||||
|
||||
const [filteredData, setFilteredData] = useState([]);
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
resetPagination: () => {
|
||||
setCurrentPage(1);
|
||||
setSearchText('');
|
||||
},
|
||||
}));
|
||||
|
||||
useEffect(() => {
|
||||
if (!allChannels) return;
|
||||
|
||||
const searchLower = searchText.trim().toLowerCase();
|
||||
const matched = searchLower
|
||||
? allChannels.filter((item) => {
|
||||
const name = (item.label || '').toLowerCase();
|
||||
const baseUrl = (item._originalData?.base_url || '').toLowerCase();
|
||||
return name.includes(searchLower) || baseUrl.includes(searchLower);
|
||||
})
|
||||
: allChannels;
|
||||
|
||||
setFilteredData(matched);
|
||||
}, [allChannels, searchText]);
|
||||
|
||||
const total = filteredData.length;
|
||||
|
||||
const paginatedData = filteredData.slice(
|
||||
(currentPage - 1) * pageSize,
|
||||
currentPage * pageSize,
|
||||
);
|
||||
|
||||
const updateEndpoint = (channelId, endpoint) => {
|
||||
if (typeof updateChannelEndpoint === 'function') {
|
||||
updateChannelEndpoint(channelId, endpoint);
|
||||
}
|
||||
};
|
||||
|
||||
const renderEndpointCell = (text, record) => {
|
||||
const channelId = record.key || record.value;
|
||||
const currentEndpoint = channelEndpoints[channelId] || '';
|
||||
|
||||
const getEndpointType = (ep) => {
|
||||
if (ep === '/api/ratio_config') return 'ratio_config';
|
||||
if (ep === '/api/pricing') return 'pricing';
|
||||
return 'custom';
|
||||
};
|
||||
|
||||
const currentType = getEndpointType(currentEndpoint);
|
||||
|
||||
const handleTypeChange = (val) => {
|
||||
if (val === 'ratio_config') {
|
||||
updateEndpoint(channelId, '/api/ratio_config');
|
||||
} else if (val === 'pricing') {
|
||||
updateEndpoint(channelId, '/api/pricing');
|
||||
} else {
|
||||
if (currentType !== 'custom') {
|
||||
updateEndpoint(channelId, '');
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 8 }}>
|
||||
<Select
|
||||
size="small"
|
||||
value={currentType}
|
||||
onChange={handleTypeChange}
|
||||
style={{ width: 120 }}
|
||||
optionList={[
|
||||
{ label: 'ratio_config', value: 'ratio_config' },
|
||||
{ label: 'pricing', value: 'pricing' },
|
||||
{ label: 'custom', value: 'custom' },
|
||||
]}
|
||||
/>
|
||||
{currentType === 'custom' && (
|
||||
<Input
|
||||
size="small"
|
||||
value={currentEndpoint}
|
||||
onChange={(val) => updateEndpoint(channelId, val)}
|
||||
placeholder="/your/endpoint"
|
||||
style={{ width: 160, fontSize: 12 }}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const renderStatusCell = (status) => {
|
||||
switch (status) {
|
||||
case 1:
|
||||
return (
|
||||
<Tag size='large' color='green' shape='circle' prefixIcon={<CheckCircle size={14} />}>
|
||||
{t('已启用')}
|
||||
</Tag>
|
||||
);
|
||||
case 2:
|
||||
return (
|
||||
<Tag size='large' color='red' shape='circle' prefixIcon={<XCircle size={14} />}>
|
||||
{t('已禁用')}
|
||||
</Tag>
|
||||
);
|
||||
case 3:
|
||||
return (
|
||||
<Tag size='large' color='yellow' shape='circle' prefixIcon={<AlertCircle size={14} />}>
|
||||
{t('自动禁用')}
|
||||
</Tag>
|
||||
);
|
||||
default:
|
||||
return (
|
||||
<Tag size='large' color='grey' shape='circle' prefixIcon={<HelpCircle size={14} />}>
|
||||
{t('未知状态')}
|
||||
</Tag>
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const renderNameCell = (text) => (
|
||||
<Highlight sourceString={text} searchWords={[searchText]} />
|
||||
);
|
||||
|
||||
const renderBaseUrlCell = (text) => (
|
||||
<Highlight sourceString={text} searchWords={[searchText]} />
|
||||
);
|
||||
|
||||
const columns = [
|
||||
{
|
||||
title: t('名称'),
|
||||
dataIndex: 'label',
|
||||
render: renderNameCell,
|
||||
},
|
||||
{
|
||||
title: t('源地址'),
|
||||
dataIndex: '_originalData.base_url',
|
||||
render: (_, record) => renderBaseUrlCell(record._originalData?.base_url || ''),
|
||||
},
|
||||
{
|
||||
title: t('状态'),
|
||||
dataIndex: '_originalData.status',
|
||||
render: (_, record) => renderStatusCell(record._originalData?.status || 0),
|
||||
},
|
||||
{
|
||||
title: t('同步接口'),
|
||||
dataIndex: 'endpoint',
|
||||
fixed: 'right',
|
||||
render: renderEndpointCell,
|
||||
},
|
||||
];
|
||||
|
||||
const rowSelection = {
|
||||
selectedRowKeys: selectedChannelIds,
|
||||
onChange: (keys) => setSelectedChannelIds(keys),
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
visible={visible}
|
||||
onCancel={onCancel}
|
||||
onOk={onOk}
|
||||
title={<span className="text-lg font-semibold">{t('选择同步渠道')}</span>}
|
||||
size={isMobile() ? 'full-width' : 'large'}
|
||||
keepDOM
|
||||
lazyRender={false}
|
||||
>
|
||||
<Space vertical style={{ width: '100%' }}>
|
||||
<Input
|
||||
prefix={<IconSearch size={14} />}
|
||||
placeholder={t('搜索渠道名称或地址')}
|
||||
value={searchText}
|
||||
onChange={setSearchText}
|
||||
showClear
|
||||
className="!rounded-full"
|
||||
/>
|
||||
|
||||
<Table
|
||||
columns={columns}
|
||||
dataSource={paginatedData}
|
||||
rowKey="key"
|
||||
rowSelection={rowSelection}
|
||||
pagination={{
|
||||
currentPage: currentPage,
|
||||
pageSize: pageSize,
|
||||
total: total,
|
||||
showSizeChanger: true,
|
||||
showQuickJumper: true,
|
||||
pageSizeOptions: ['10', '20', '50', '100'],
|
||||
formatPageText: (page) => t('第 {{start}} - {{end}} 条,共 {{total}} 条', {
|
||||
start: page.currentStart,
|
||||
end: page.currentEnd,
|
||||
total: total,
|
||||
}),
|
||||
onChange: (page, size) => {
|
||||
setCurrentPage(page);
|
||||
setPageSize(size);
|
||||
},
|
||||
onShowSizeChange: (curr, size) => {
|
||||
setCurrentPage(1);
|
||||
setPageSize(size);
|
||||
},
|
||||
}}
|
||||
size="small"
|
||||
/>
|
||||
</Space>
|
||||
</Modal>
|
||||
);
|
||||
});
|
||||
|
||||
export default ChannelSelectorModal;
|
||||
63
web/src/components/settings/ChatsSetting.js
Normal file
63
web/src/components/settings/ChatsSetting.js
Normal file
@@ -0,0 +1,63 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Card, Spin } from '@douyinfe/semi-ui';
|
||||
import SettingsChats from '../../pages/Setting/Chat/SettingsChats.js';
|
||||
import { API, showError } from '../../helpers';
|
||||
|
||||
const ChatsSetting = () => {
|
||||
let [inputs, setInputs] = useState({
|
||||
/* 聊天设置 */
|
||||
Chats: '[]',
|
||||
});
|
||||
|
||||
let [loading, setLoading] = useState(false);
|
||||
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
let newInputs = {};
|
||||
data.forEach((item) => {
|
||||
if (
|
||||
item.key.endsWith('Enabled') ||
|
||||
['DefaultCollapseSidebar'].includes(item.key)
|
||||
) {
|
||||
newInputs[item.key] = item.value === 'true' ? true : false;
|
||||
} else {
|
||||
newInputs[item.key] = item.value;
|
||||
}
|
||||
});
|
||||
|
||||
setInputs(newInputs);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
async function onRefresh() {
|
||||
try {
|
||||
setLoading(true);
|
||||
await getOptions();
|
||||
} catch (error) {
|
||||
showError('刷新失败');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
onRefresh();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Spin spinning={loading} size='large'>
|
||||
{/* 聊天设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsChats options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
</Spin>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default ChatsSetting;
|
||||
@@ -5,6 +5,7 @@ import SettingsAPIInfo from '../../pages/Setting/Dashboard/SettingsAPIInfo.js';
|
||||
import SettingsAnnouncements from '../../pages/Setting/Dashboard/SettingsAnnouncements.js';
|
||||
import SettingsFAQ from '../../pages/Setting/Dashboard/SettingsFAQ.js';
|
||||
import SettingsUptimeKuma from '../../pages/Setting/Dashboard/SettingsUptimeKuma.js';
|
||||
import SettingsDataDashboard from '../../pages/Setting/Dashboard/SettingsDataDashboard.js';
|
||||
|
||||
const DashboardSetting = () => {
|
||||
let [inputs, setInputs] = useState({
|
||||
@@ -23,6 +24,11 @@ const DashboardSetting = () => {
|
||||
FAQ: '',
|
||||
UptimeKumaUrl: '',
|
||||
UptimeKumaSlug: '',
|
||||
|
||||
/* 数据看板 */
|
||||
DataExportEnabled: false,
|
||||
DataExportDefaultTime: 'hour',
|
||||
DataExportInterval: 5,
|
||||
});
|
||||
|
||||
let [loading, setLoading] = useState(false);
|
||||
@@ -37,6 +43,10 @@ const DashboardSetting = () => {
|
||||
if (item.key in inputs) {
|
||||
newInputs[item.key] = item.value;
|
||||
}
|
||||
if (item.key.endsWith('Enabled') &&
|
||||
(item.key === 'DataExportEnabled')) {
|
||||
newInputs[item.key] = item.value === 'true' ? true : false;
|
||||
}
|
||||
});
|
||||
setInputs(newInputs);
|
||||
} else {
|
||||
@@ -106,9 +116,9 @@ const DashboardSetting = () => {
|
||||
</p>
|
||||
</Modal>
|
||||
|
||||
{/* API信息管理 */}
|
||||
{/* 数据看板设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsAPIInfo options={inputs} refresh={onRefresh} />
|
||||
<SettingsDataDashboard options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
|
||||
{/* 系统公告管理 */}
|
||||
@@ -116,6 +126,11 @@ const DashboardSetting = () => {
|
||||
<SettingsAnnouncements options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
|
||||
{/* API信息管理 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsAPIInfo options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
|
||||
{/* 常见问答管理 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsFAQ options={inputs} refresh={onRefresh} />
|
||||
|
||||
65
web/src/components/settings/DrawingSetting.js
Normal file
65
web/src/components/settings/DrawingSetting.js
Normal file
@@ -0,0 +1,65 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Card, Spin } from '@douyinfe/semi-ui';
|
||||
import SettingsDrawing from '../../pages/Setting/Drawing/SettingsDrawing.js';
|
||||
import { API, showError } from '../../helpers';
|
||||
|
||||
const DrawingSetting = () => {
|
||||
let [inputs, setInputs] = useState({
|
||||
/* 绘图设置 */
|
||||
DrawingEnabled: false,
|
||||
MjNotifyEnabled: false,
|
||||
MjAccountFilterEnabled: false,
|
||||
MjForwardUrlEnabled: false,
|
||||
MjModeClearEnabled: false,
|
||||
MjActionCheckSuccessEnabled: false,
|
||||
});
|
||||
|
||||
let [loading, setLoading] = useState(false);
|
||||
|
||||
const getOptions = async () => {
|
||||
const res = await API.get('/api/option/');
|
||||
const { success, message, data } = res.data;
|
||||
if (success) {
|
||||
let newInputs = {};
|
||||
data.forEach((item) => {
|
||||
if (item.key.endsWith('Enabled')) {
|
||||
newInputs[item.key] = item.value === 'true' ? true : false;
|
||||
} else {
|
||||
newInputs[item.key] = item.value;
|
||||
}
|
||||
});
|
||||
|
||||
setInputs(newInputs);
|
||||
} else {
|
||||
showError(message);
|
||||
}
|
||||
};
|
||||
|
||||
async function onRefresh() {
|
||||
try {
|
||||
setLoading(true);
|
||||
await getOptions();
|
||||
} catch (error) {
|
||||
showError('刷新失败');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
onRefresh();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Spin spinning={loading} size='large'>
|
||||
{/* 绘图设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsDrawing options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
</Spin>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default DrawingSetting;
|
||||
@@ -1,66 +1,44 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
|
||||
import { Card, Spin } from '@douyinfe/semi-ui';
|
||||
import SettingsGeneral from '../../pages/Setting/Operation/SettingsGeneral.js';
|
||||
import SettingsDrawing from '../../pages/Setting/Operation/SettingsDrawing.js';
|
||||
import SettingsSensitiveWords from '../../pages/Setting/Operation/SettingsSensitiveWords.js';
|
||||
import SettingsLog from '../../pages/Setting/Operation/SettingsLog.js';
|
||||
import SettingsDataDashboard from '../../pages/Setting/Operation/SettingsDataDashboard.js';
|
||||
import SettingsMonitoring from '../../pages/Setting/Operation/SettingsMonitoring.js';
|
||||
import SettingsCreditLimit from '../../pages/Setting/Operation/SettingsCreditLimit.js';
|
||||
import ModelSettingsVisualEditor from '../../pages/Setting/Operation/ModelSettingsVisualEditor.js';
|
||||
import GroupRatioSettings from '../../pages/Setting/Operation/GroupRatioSettings.js';
|
||||
import ModelRatioSettings from '../../pages/Setting/Operation/ModelRatioSettings.js';
|
||||
|
||||
import { API, showError, showSuccess } from '../../helpers';
|
||||
import SettingsChats from '../../pages/Setting/Operation/SettingsChats.js';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import ModelRatioNotSetEditor from '../../pages/Setting/Operation/ModelRationNotSetEditor.js';
|
||||
import { API, showError } from '../../helpers';
|
||||
|
||||
const OperationSetting = () => {
|
||||
const { t } = useTranslation();
|
||||
let [inputs, setInputs] = useState({
|
||||
/* 额度相关 */
|
||||
QuotaForNewUser: 0,
|
||||
PreConsumedQuota: 0,
|
||||
QuotaForInviter: 0,
|
||||
QuotaForInvitee: 0,
|
||||
QuotaRemindThreshold: 0,
|
||||
PreConsumedQuota: 0,
|
||||
StreamCacheQueueLength: 0,
|
||||
ModelRatio: '',
|
||||
CacheRatio: '',
|
||||
CompletionRatio: '',
|
||||
ModelPrice: '',
|
||||
GroupRatio: '',
|
||||
GroupGroupRatio: '',
|
||||
UserUsableGroups: '',
|
||||
|
||||
/* 通用设置 */
|
||||
TopUpLink: '',
|
||||
'general_setting.docs_link': '',
|
||||
// ChatLink2: '', // 添加的新状态变量
|
||||
QuotaPerUnit: 0,
|
||||
AutomaticDisableChannelEnabled: false,
|
||||
AutomaticEnableChannelEnabled: false,
|
||||
ChannelDisableThreshold: 0,
|
||||
LogConsumeEnabled: false,
|
||||
RetryTimes: 0,
|
||||
DisplayInCurrencyEnabled: false,
|
||||
DisplayTokenStatEnabled: false,
|
||||
CheckSensitiveEnabled: false,
|
||||
CheckSensitiveOnPromptEnabled: false,
|
||||
CheckSensitiveOnCompletionEnabled: '',
|
||||
StopOnSensitiveEnabled: '',
|
||||
SensitiveWords: '',
|
||||
MjNotifyEnabled: false,
|
||||
MjAccountFilterEnabled: false,
|
||||
MjModeClearEnabled: false,
|
||||
MjForwardUrlEnabled: false,
|
||||
MjActionCheckSuccessEnabled: false,
|
||||
DrawingEnabled: false,
|
||||
DataExportEnabled: false,
|
||||
DataExportDefaultTime: 'hour',
|
||||
DataExportInterval: 5,
|
||||
DefaultCollapseSidebar: false, // 默认折叠侧边栏
|
||||
RetryTimes: 0,
|
||||
Chats: '[]',
|
||||
DefaultCollapseSidebar: false,
|
||||
DemoSiteEnabled: false,
|
||||
SelfUseModeEnabled: false,
|
||||
|
||||
/* 敏感词设置 */
|
||||
CheckSensitiveEnabled: false,
|
||||
CheckSensitiveOnPromptEnabled: false,
|
||||
SensitiveWords: '',
|
||||
|
||||
/* 日志设置 */
|
||||
LogConsumeEnabled: false,
|
||||
|
||||
/* 监控设置 */
|
||||
ChannelDisableThreshold: 0,
|
||||
QuotaRemindThreshold: 0,
|
||||
AutomaticDisableChannelEnabled: false,
|
||||
AutomaticEnableChannelEnabled: false,
|
||||
AutomaticDisableKeywords: '',
|
||||
});
|
||||
|
||||
@@ -72,17 +50,6 @@ const OperationSetting = () => {
|
||||
if (success) {
|
||||
let newInputs = {};
|
||||
data.forEach((item) => {
|
||||
if (
|
||||
item.key === 'ModelRatio' ||
|
||||
item.key === 'GroupRatio' ||
|
||||
item.key === 'GroupGroupRatio' ||
|
||||
item.key === 'UserUsableGroups' ||
|
||||
item.key === 'CompletionRatio' ||
|
||||
item.key === 'ModelPrice' ||
|
||||
item.key === 'CacheRatio'
|
||||
) {
|
||||
item.value = JSON.stringify(JSON.parse(item.value), null, 2);
|
||||
}
|
||||
if (
|
||||
item.key.endsWith('Enabled') ||
|
||||
['DefaultCollapseSidebar'].includes(item.key)
|
||||
@@ -121,10 +88,6 @@ const OperationSetting = () => {
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsGeneral options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
{/* 绘图设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsDrawing options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
{/* 屏蔽词过滤设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsSensitiveWords options={inputs} refresh={onRefresh} />
|
||||
@@ -133,10 +96,6 @@ const OperationSetting = () => {
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsLog options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
{/* 数据看板 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsDataDashboard options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
{/* 监控设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsMonitoring options={inputs} refresh={onRefresh} />
|
||||
@@ -145,28 +104,6 @@ const OperationSetting = () => {
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsCreditLimit options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
{/* 聊天设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<SettingsChats options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
{/* 分组倍率设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<GroupRatioSettings options={inputs} refresh={onRefresh} />
|
||||
</Card>
|
||||
{/* 合并模型倍率设置和可视化倍率设置 */}
|
||||
<Card style={{ marginTop: '10px' }}>
|
||||
<Tabs type='line'>
|
||||
<Tabs.TabPane tab={t('模型倍率设置')} itemKey='model'>
|
||||
<ModelRatioSettings options={inputs} refresh={onRefresh} />
|
||||
</Tabs.TabPane>
|
||||
<Tabs.TabPane tab={t('可视化倍率设置')} itemKey='visual'>
|
||||
<ModelSettingsVisualEditor options={inputs} refresh={onRefresh} />
|
||||
</Tabs.TabPane>
|
||||
<Tabs.TabPane tab={t('未设置倍率模型')} itemKey='unset_models'>
|
||||
<ModelRatioNotSetEditor options={inputs} refresh={onRefresh} />
|
||||
</Tabs.TabPane>
|
||||
</Tabs>
|
||||
</Card>
|
||||
</Spin>
|
||||
</>
|
||||
);
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user