mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-05 06:27:19 +00:00
Compare commits
197 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
647f8d7958 | ||
|
|
5d289d38ba | ||
|
|
05ea0dd54f | ||
|
|
1dad04ec09 | ||
|
|
2171117c53 | ||
|
|
3ced5ff144 | ||
|
|
38d3ab5acf | ||
|
|
ab32e15a86 | ||
|
|
25e17b95d5 | ||
|
|
d07224e658 | ||
|
|
aa15d45a3d | ||
|
|
1a0aac81df | ||
|
|
39cb45c11c | ||
|
|
05d9aa53ef | ||
|
|
86f374df58 | ||
|
|
6935260bf0 | ||
|
|
f0d888729b | ||
|
|
6d7d4292ef | ||
|
|
fcefac9dbe | ||
|
|
ad5f731b20 | ||
|
|
0689670698 | ||
|
|
5a6f32c392 | ||
|
|
d6276c4692 | ||
|
|
29a44eb7ae | ||
|
|
048a625181 | ||
|
|
64782027c4 | ||
|
|
277645db50 | ||
|
|
3f53e4f53e | ||
|
|
0c5d4ca0a7 | ||
|
|
44495b153a | ||
|
|
de6e551cdb | ||
|
|
aeb393e391 | ||
|
|
db1b11deaf | ||
|
|
5a5e8ce652 | ||
|
|
6c31151430 | ||
|
|
a8ba2eba33 | ||
|
|
c974b1053c | ||
|
|
1ab75b8a92 | ||
|
|
75e3959474 | ||
|
|
bc371778b6 | ||
|
|
cd2870aebc | ||
|
|
7c72545217 | ||
|
|
2591ca3d60 | ||
|
|
c28190316f | ||
|
|
ffc22b8dac | ||
|
|
5367015a31 | ||
|
|
75c71c397e | ||
|
|
6192aebe66 | ||
|
|
a85a594597 | ||
|
|
014c9450ba | ||
|
|
63640f65e8 | ||
|
|
fd040988a3 | ||
|
|
f7c3b043b5 | ||
|
|
93e7675bc3 | ||
|
|
d7c97d4d34 | ||
|
|
dce794dbf7 | ||
|
|
093d86040f | ||
|
|
39617bc8c6 | ||
|
|
7da224ba92 | ||
|
|
df862732df | ||
|
|
fd4447f60a | ||
|
|
ea79d59aa0 | ||
|
|
41b0cf406c | ||
|
|
ef32cc8e0a | ||
|
|
ee8956b0e9 | ||
|
|
5ad9f8d931 | ||
|
|
ea379e1d0e | ||
|
|
b842baf21f | ||
|
|
58c9c7d5dd | ||
|
|
384fadf227 | ||
|
|
e4def0625b | ||
|
|
44d20de251 | ||
|
|
7ea33c2ddf | ||
|
|
b43423bffc | ||
|
|
cf4700a35c | ||
|
|
6bb552128c | ||
|
|
50b4fc06f8 | ||
|
|
f7f1be9df2 | ||
|
|
59574dc80f | ||
|
|
7577ec1ac4 | ||
|
|
d487be0029 | ||
|
|
83a3872b97 | ||
|
|
1ad2f63f85 | ||
|
|
fcaa8317e4 | ||
|
|
ccda14255a | ||
|
|
8d66828229 | ||
|
|
4ebf9e35e1 | ||
|
|
2902d6c7c2 | ||
|
|
01ef1fe4e4 | ||
|
|
c3d2d07b68 | ||
|
|
18417bacb3 | ||
|
|
8ec18dd21b | ||
|
|
edaff1c689 | ||
|
|
9c3a13cb23 | ||
|
|
0b326e7af4 | ||
|
|
1a1ff836b5 | ||
|
|
34fed74f64 | ||
|
|
f89b29928c | ||
|
|
2c6d4460c3 | ||
|
|
7afd3f97ee | ||
|
|
0708452939 | ||
|
|
a9e5d99ea3 | ||
|
|
a56d9ea98b | ||
|
|
f5e80af0b3 | ||
|
|
a1a7ddbc83 | ||
|
|
8b209d8926 | ||
|
|
9344cab59a | ||
|
|
03468e05e4 | ||
|
|
11792ba1a4 | ||
|
|
5baaa06896 | ||
|
|
d3286893c4 | ||
|
|
b087b20bac | ||
|
|
6a5a839d4d | ||
|
|
5d8a0952b4 | ||
|
|
bd08ecc1e0 | ||
|
|
e4f61c1084 | ||
|
|
a38215478f | ||
|
|
c192d07a04 | ||
|
|
098880b796 | ||
|
|
150c506ece | ||
|
|
f978d8224e | ||
|
|
ab59887933 | ||
|
|
458472f3e2 | ||
|
|
a9f98c5d39 | ||
|
|
2b7dff2d94 | ||
|
|
58752d2dcf | ||
|
|
67546f4b2a | ||
|
|
8e9dae7b5f | ||
|
|
fb4ff63bad | ||
|
|
1fed1ee567 | ||
|
|
02571c20ff | ||
|
|
8201daa4b4 | ||
|
|
5b54624cd5 | ||
|
|
db737567fb | ||
|
|
5c3898d13e | ||
|
|
2fdb2be6d0 | ||
|
|
ab78efc815 | ||
|
|
fcf97d1796 | ||
|
|
e85cc6acbe | ||
|
|
da002e6ca9 | ||
|
|
070e7b6911 | ||
|
|
d5a3eb7d04 | ||
|
|
616e6953cc | ||
|
|
b7c77777a5 | ||
|
|
8a79de333a | ||
|
|
a87d4271d3 | ||
|
|
7975cdf3bf | ||
|
|
b2badad554 | ||
|
|
133d8c9f77 | ||
|
|
9708d645d3 | ||
|
|
0bca4d3efc | ||
|
|
7572e791f6 | ||
|
|
16c63b3be9 | ||
|
|
37fbcb7950 | ||
|
|
a180d13182 | ||
|
|
a6363a502a | ||
|
|
81bc096872 | ||
|
|
edcdb378fd | ||
|
|
4447e51588 | ||
|
|
fb8aac650f | ||
|
|
ba6b0637cc | ||
|
|
3502730dfc | ||
|
|
b95c5bb8f4 | ||
|
|
f35784aa97 | ||
|
|
3746482e8c | ||
|
|
5ed4b60b8f | ||
|
|
547da2da60 | ||
|
|
f88ed4dd5c | ||
|
|
87fc681df3 | ||
|
|
a39b2f5aa7 | ||
|
|
0b9b21eafd | ||
|
|
21f43b0dd8 | ||
|
|
3a7ba5725c | ||
|
|
2e4fa32d63 | ||
|
|
0199896d9a | ||
|
|
edd9049100 | ||
|
|
290c763901 | ||
|
|
226446a3b5 | ||
|
|
ab627db4be | ||
|
|
0f35d2368f | ||
|
|
3c276d13c4 | ||
|
|
b7c3328d43 | ||
|
|
4d8e63bd1a | ||
|
|
7fa21ce95f | ||
|
|
296da5dbcc | ||
|
|
1ec2bbd533 | ||
|
|
d67d5d8006 | ||
|
|
c4f25a77d1 | ||
|
|
52763c09f2 | ||
|
|
e77555a04f | ||
|
|
4a313a5f93 | ||
|
|
3d9587f128 | ||
|
|
66778efcc5 | ||
|
|
6be78ff283 | ||
|
|
5281f2ba64 | ||
|
|
69420f713f | ||
|
|
bc322ddac4 |
@@ -59,7 +59,7 @@
|
||||
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
|
||||
# DIFY_DEBUG=true
|
||||
# 设置流式一次回复的超时时间
|
||||
# STREAMING_TIMEOUT=90
|
||||
# STREAMING_TIMEOUT=120
|
||||
|
||||
|
||||
# 节点类型
|
||||
|
||||
@@ -24,8 +24,7 @@ RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)'" -o one-
|
||||
|
||||
FROM alpine
|
||||
|
||||
RUN apk update \
|
||||
&& apk upgrade \
|
||||
RUN apk upgrade --no-cache \
|
||||
&& apk add --no-cache ca-certificates tzdata ffmpeg \
|
||||
&& update-ca-certificates
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ This version supports multiple models, please refer to [API Documentation-Relay
|
||||
For detailed configuration instructions, please refer to [Installation Guide-Environment Variables Configuration](https://docs.newapi.pro/installation/environment-variables):
|
||||
|
||||
- `GENERATE_DEFAULT_TOKEN`: Whether to generate initial tokens for newly registered users, default is `false`
|
||||
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 60 seconds
|
||||
- `STREAMING_TIMEOUT`: Streaming response timeout, default is 120 seconds
|
||||
- `DIFY_DEBUG`: Whether to output workflow and node information for Dify channels, default is `true`
|
||||
- `FORCE_STREAM_OPTION`: Whether to override client stream_options parameter, default is `true`
|
||||
- `GET_MEDIA_TOKEN`: Whether to count image tokens, default is `true`
|
||||
|
||||
@@ -103,7 +103,7 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
|
||||
详细配置说明请参考[安装指南-环境变量配置](https://docs.newapi.pro/installation/environment-variables):
|
||||
|
||||
- `GENERATE_DEFAULT_TOKEN`:是否为新注册用户生成初始令牌,默认为 `false`
|
||||
- `STREAMING_TIMEOUT`:流式回复超时时间,默认60秒
|
||||
- `STREAMING_TIMEOUT`:流式回复超时时间,默认120秒
|
||||
- `DIFY_DEBUG`:Dify渠道是否输出工作流和节点信息,默认 `true`
|
||||
- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,默认 `true`
|
||||
- `GET_MEDIA_TOKEN`:是否统计图片token,默认 `true`
|
||||
|
||||
@@ -241,6 +241,8 @@ const (
|
||||
ChannelTypeXinference = 47
|
||||
ChannelTypeXai = 48
|
||||
ChannelTypeCoze = 49
|
||||
ChannelTypeKling = 50
|
||||
ChannelTypeJimeng = 51
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
)
|
||||
@@ -296,4 +298,6 @@ var ChannelBaseURLs = []string{
|
||||
"", //47
|
||||
"https://api.x.ai", //48
|
||||
"https://api.coze.cn", //49
|
||||
"https://api.klingai.com", //50
|
||||
"https://visual.volcengineapi.com", //51
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
@@ -284,3 +285,20 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64
|
||||
}
|
||||
return strconv.ParseFloat(durationStr, 64)
|
||||
}
|
||||
|
||||
// BuildURL concatenates base and endpoint, returns the complete url string
|
||||
func BuildURL(base string, endpoint string) string {
|
||||
u, err := url.Parse(base)
|
||||
if err != nil {
|
||||
return base + endpoint
|
||||
}
|
||||
end := endpoint
|
||||
if end == "" {
|
||||
end = "/"
|
||||
}
|
||||
ref, err := url.Parse(end)
|
||||
if err != nil {
|
||||
return base + endpoint
|
||||
}
|
||||
return u.ResolveReference(ref).String()
|
||||
}
|
||||
|
||||
@@ -7,4 +7,5 @@ const (
|
||||
ContextKeyUserStatus = "user_status"
|
||||
ContextKeyUserEmail = "user_email"
|
||||
ContextKeyUserGroup = "user_group"
|
||||
ContextKeyUsingGroup = "group"
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ var ErrorLogEnabled bool
|
||||
//}
|
||||
|
||||
func InitEnv() {
|
||||
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
|
||||
StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 120)
|
||||
DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true)
|
||||
MaxFileDownloadMB = common.GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
|
||||
@@ -5,6 +5,8 @@ type TaskPlatform string
|
||||
const (
|
||||
TaskPlatformSuno TaskPlatform = "suno"
|
||||
TaskPlatformMidjourney = "mj"
|
||||
TaskPlatformKling TaskPlatform = "kling"
|
||||
TaskPlatformJimeng TaskPlatform = "jimeng"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -4,11 +4,13 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/shopspring/decimal"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -304,6 +306,40 @@ func updateChannelOpenRouterBalance(channel *model.Channel) (float64, error) {
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
||||
url := "https://api.moonshot.cn/v1/users/me/balance"
|
||||
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
type MoonshotBalanceData struct {
|
||||
AvailableBalance float64 `json:"available_balance"`
|
||||
VoucherBalance float64 `json:"voucher_balance"`
|
||||
CashBalance float64 `json:"cash_balance"`
|
||||
}
|
||||
|
||||
type MoonshotBalanceResponse struct {
|
||||
Code int `json:"code"`
|
||||
Data MoonshotBalanceData `json:"data"`
|
||||
Scode string `json:"scode"`
|
||||
Status bool `json:"status"`
|
||||
}
|
||||
|
||||
response := MoonshotBalanceResponse{}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !response.Status || response.Code != 0 {
|
||||
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
||||
}
|
||||
availableBalanceCny := response.Data.AvailableBalance
|
||||
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
|
||||
channel.UpdateBalance(availableBalanceUsd)
|
||||
return availableBalanceUsd, nil
|
||||
}
|
||||
|
||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() == "" {
|
||||
@@ -332,6 +368,8 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
return updateChannelDeepSeekBalance(channel)
|
||||
case common.ChannelTypeOpenRouter:
|
||||
return updateChannelOpenRouterBalance(channel)
|
||||
case common.ChannelTypeMoonshot:
|
||||
return updateChannelMoonshotBalance(channel)
|
||||
default:
|
||||
return 0, errors.New("尚未实现")
|
||||
}
|
||||
|
||||
@@ -40,6 +40,12 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
if channel.Type == common.ChannelTypeSunoAPI {
|
||||
return errors.New("suno channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeKling {
|
||||
return errors.New("kling channel test is not supported"), nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeJimeng {
|
||||
return errors.New("jimeng channel test is not supported"), nil
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
@@ -90,7 +96,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
|
||||
info := relaycommon.GenRelayInfo(c)
|
||||
|
||||
err = helper.ModelMappedHelper(c, info)
|
||||
err = helper.ModelMappedHelper(c, info, nil)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
@@ -165,10 +171,10 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
consumedTime := float64(milliseconds) / 1000.0
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.UserGroupRatio)
|
||||
other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
||||
usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试",
|
||||
quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other)
|
||||
quota, "模型测试", 0, quota, int(consumedTime), false, info.UsingGroup, other)
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
}
|
||||
@@ -312,7 +318,7 @@ func testAllChannels(notify bool) error {
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
time.Sleep(common.RequestInterval)
|
||||
}
|
||||
|
||||
|
||||
if notify {
|
||||
service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
|
||||
}
|
||||
|
||||
@@ -40,6 +40,17 @@ type OpenAIModelsResponse struct {
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
func parseStatusFilter(statusParam string) int {
|
||||
switch strings.ToLower(statusParam) {
|
||||
case "enabled", "1":
|
||||
return common.ChannelStatusEnabled
|
||||
case "disabled", "0":
|
||||
return 0
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllChannels(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
pageSize, _ := strconv.Atoi(c.Query("page_size"))
|
||||
@@ -52,44 +63,100 @@ func GetAllChannels(c *gin.Context) {
|
||||
channelData := make([]*model.Channel, 0)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
statusParam := c.Query("status")
|
||||
// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
|
||||
statusFilter := parseStatusFilter(statusParam)
|
||||
// type filter
|
||||
typeStr := c.Query("type")
|
||||
typeFilter := -1
|
||||
if typeStr != "" {
|
||||
if t, err := strconv.Atoi(typeStr); err == nil {
|
||||
typeFilter = t
|
||||
}
|
||||
}
|
||||
|
||||
var total int64
|
||||
|
||||
if enableTagMode {
|
||||
// tag 分页:先分页 tag,再取各 tag 下 channels
|
||||
tags, err := model.GetPaginatedTags((p-1)*pageSize, pageSize)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
for _, tag := range tags {
|
||||
if tag != nil && *tag != "" {
|
||||
tagChannel, err := model.GetChannelsByTag(*tag, idSort)
|
||||
if err == nil {
|
||||
channelData = append(channelData, tagChannel...)
|
||||
}
|
||||
if tag == nil || *tag == "" {
|
||||
continue
|
||||
}
|
||||
tagChannels, err := model.GetChannelsByTag(*tag, idSort)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
filtered := make([]*model.Channel, 0)
|
||||
for _, ch := range tagChannels {
|
||||
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if typeFilter >= 0 && ch.Type != typeFilter {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, ch)
|
||||
}
|
||||
channelData = append(channelData, filtered...)
|
||||
}
|
||||
// 计算 tag 总数用于分页
|
||||
total, _ = model.CountAllTags()
|
||||
} else {
|
||||
channels, err := model.GetAllChannels((p-1)*pageSize, pageSize, false, idSort)
|
||||
baseQuery := model.DB.Model(&model.Channel{})
|
||||
if typeFilter >= 0 {
|
||||
baseQuery = baseQuery.Where("type = ?", typeFilter)
|
||||
}
|
||||
if statusFilter == common.ChannelStatusEnabled {
|
||||
baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||
} else if statusFilter == 0 {
|
||||
baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||||
}
|
||||
|
||||
baseQuery.Count(&total)
|
||||
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
|
||||
err := baseQuery.Order(order).Limit(pageSize).Offset((p-1)*pageSize).Omit("key").Find(&channelData).Error
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
channelData = channels
|
||||
total, _ = model.CountAllChannels()
|
||||
}
|
||||
|
||||
countQuery := model.DB.Model(&model.Channel{})
|
||||
if statusFilter == common.ChannelStatusEnabled {
|
||||
countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
|
||||
} else if statusFilter == 0 {
|
||||
countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
|
||||
}
|
||||
var results []struct {
|
||||
Type int64
|
||||
Count int64
|
||||
}
|
||||
_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
|
||||
typeCounts := make(map[int64]int64)
|
||||
for _, r := range results {
|
||||
typeCounts[r.Type] = r.Count
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": gin.H{
|
||||
"items": channelData,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
"items": channelData,
|
||||
"total": total,
|
||||
"page": p,
|
||||
"page_size": pageSize,
|
||||
"type_counts": typeCounts,
|
||||
},
|
||||
})
|
||||
return
|
||||
@@ -114,13 +181,6 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
//if channel.Type != common.ChannelTypeOpenAI {
|
||||
// c.JSON(http.StatusOK, gin.H{
|
||||
// "success": false,
|
||||
// "message": "仅支持 OpenAI 类型渠道",
|
||||
// })
|
||||
// return
|
||||
//}
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
@@ -186,6 +246,8 @@ func SearchChannels(c *gin.Context) {
|
||||
keyword := c.Query("keyword")
|
||||
group := c.Query("group")
|
||||
modelKeyword := c.Query("model")
|
||||
statusParam := c.Query("status")
|
||||
statusFilter := parseStatusFilter(statusParam)
|
||||
idSort, _ := strconv.ParseBool(c.Query("id_sort"))
|
||||
enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
|
||||
channelData := make([]*model.Channel, 0)
|
||||
@@ -217,10 +279,74 @@ func SearchChannels(c *gin.Context) {
|
||||
}
|
||||
channelData = channels
|
||||
}
|
||||
|
||||
if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
|
||||
filtered := make([]*model.Channel, 0, len(channelData))
|
||||
for _, ch := range channelData {
|
||||
if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, ch)
|
||||
}
|
||||
channelData = filtered
|
||||
}
|
||||
|
||||
// calculate type counts for search results
|
||||
typeCounts := make(map[int64]int64)
|
||||
for _, channel := range channelData {
|
||||
typeCounts[int64(channel.Type)]++
|
||||
}
|
||||
|
||||
typeParam := c.Query("type")
|
||||
typeFilter := -1
|
||||
if typeParam != "" {
|
||||
if tp, err := strconv.Atoi(typeParam); err == nil {
|
||||
typeFilter = tp
|
||||
}
|
||||
}
|
||||
|
||||
if typeFilter >= 0 {
|
||||
filtered := make([]*model.Channel, 0, len(channelData))
|
||||
for _, ch := range channelData {
|
||||
if ch.Type == typeFilter {
|
||||
filtered = append(filtered, ch)
|
||||
}
|
||||
}
|
||||
channelData = filtered
|
||||
}
|
||||
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
total := len(channelData)
|
||||
startIdx := (page - 1) * pageSize
|
||||
if startIdx > total {
|
||||
startIdx = total
|
||||
}
|
||||
endIdx := startIdx + pageSize
|
||||
if endIdx > total {
|
||||
endIdx = total
|
||||
}
|
||||
|
||||
pagedData := channelData[startIdx:endIdx]
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": channelData,
|
||||
"data": gin.H{
|
||||
"items": pagedData,
|
||||
"total": total,
|
||||
"type_counts": typeCounts,
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -516,6 +642,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
channel.Key = ""
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetGroups(c *gin.Context) {
|
||||
groupNames := make([]string, 0)
|
||||
for groupName, _ := range setting.GetGroupRatioCopy() {
|
||||
for groupName := range ratio_setting.GetGroupRatioCopy() {
|
||||
groupNames = append(groupNames, groupName)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -24,7 +26,7 @@ func GetUserGroups(c *gin.Context) {
|
||||
userGroup := ""
|
||||
userId := c.GetInt("id")
|
||||
userGroup, _ = model.GetUserGroup(userId, false)
|
||||
for groupName, ratio := range setting.GetGroupRatioCopy() {
|
||||
for groupName, ratio := range ratio_setting.GetGroupRatioCopy() {
|
||||
// UserUsableGroups contains the groups that the user can use
|
||||
userUsableGroups := setting.GetUserUsableGroups(userGroup)
|
||||
if desc, ok := userUsableGroups[groupName]; ok {
|
||||
@@ -34,6 +36,12 @@ func GetUserGroups(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
if setting.GroupInUserUsableGroups("auto") {
|
||||
usableGroups["auto"] = map[string]interface{}{
|
||||
"ratio": "自动",
|
||||
"desc": setting.GetUsableGroupDescription("auto"),
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"one-api/setting/console_setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -41,46 +41,48 @@ func GetStatus(c *gin.Context) {
|
||||
cs := console_setting.GetConsoleSetting()
|
||||
|
||||
data := gin.H{
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"version": common.Version,
|
||||
"start_time": common.StartTime,
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"telegram_oauth": common.TelegramOAuthEnabled,
|
||||
"telegram_bot_name": common.TelegramBotName,
|
||||
"system_name": common.SystemName,
|
||||
"logo": common.Logo,
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
"docs_link": operation_setting.GetGeneralSetting().DocsLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"enable_batch_update": common.BatchUpdateEnabled,
|
||||
"enable_drawing": common.DrawingEnabled,
|
||||
"enable_task": common.TaskEnabled,
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
|
||||
// 面板启用开关
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
|
||||
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||
"faq_enabled": cs.FAQEnabled,
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
"uptime_kuma_enabled": cs.UptimeKumaEnabled,
|
||||
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||
"faq_enabled": cs.FAQEnabled,
|
||||
|
||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||
|
||||
@@ -2,7 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -15,6 +15,9 @@ import (
|
||||
"one-api/relay/channel/moonshot"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
@@ -134,6 +137,9 @@ func init() {
|
||||
adaptor.Init(meta)
|
||||
channelId2Models[i] = adaptor.GetModelList()
|
||||
}
|
||||
openAIModels = lo.UniqBy(openAIModels, func(m dto.OpenAIModels) string {
|
||||
return m.Id
|
||||
})
|
||||
}
|
||||
|
||||
func ListModels(c *gin.Context) {
|
||||
@@ -179,7 +185,19 @@ func ListModels(c *gin.Context) {
|
||||
if tokenGroup != "" {
|
||||
group = tokenGroup
|
||||
}
|
||||
models := model.GetGroupModels(group)
|
||||
var models []string
|
||||
if tokenGroup == "auto" {
|
||||
for _, autoGroup := range setting.AutoGroups {
|
||||
groupModels := model.GetGroupModels(autoGroup)
|
||||
for _, g := range groupModels {
|
||||
if !common.StringsContains(models, g) {
|
||||
models = append(models, g)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
models = model.GetGroupModels(group)
|
||||
}
|
||||
for _, s := range models {
|
||||
if _, ok := openAIModelsMap[s]; ok {
|
||||
userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s])
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/console_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strings"
|
||||
|
||||
@@ -103,7 +104,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "GroupRatio":
|
||||
err = setting.CheckGroupRatio(option.Value)
|
||||
err = ratio_setting.CheckGroupRatio(option.Value)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -3,7 +3,6 @@ package controller
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
@@ -13,6 +12,8 @@ import (
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Playground(c *gin.Context) {
|
||||
@@ -57,13 +58,22 @@ func Playground(c *gin.Context) {
|
||||
c.Set("group", group)
|
||||
}
|
||||
c.Set("token_name", "playground-"+group)
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, playgroundRequest.Model, 0)
|
||||
channel, finalGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, playgroundRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, playgroundRequest.Model)
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", finalGroup, playgroundRequest.Model)
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(errors.New(message), "get_playground_channel_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
middleware.SetupContextForSelectedChannel(c, channel, playgroundRequest.Model)
|
||||
c.Set(constant.ContextKeyRequestStartTime, time.Now())
|
||||
|
||||
// Write user context to ensure acceptUnsetRatio is available
|
||||
userId := c.GetInt("id")
|
||||
userCache, err := model.GetUserCache(userId)
|
||||
if err != nil {
|
||||
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_user_cache_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
userCache.WriteContext(c)
|
||||
Relay(c)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package controller
|
||||
import (
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -13,7 +13,7 @@ func GetPricing(c *gin.Context) {
|
||||
userId, exists := c.Get("id")
|
||||
usableGroup := map[string]string{}
|
||||
groupRatio := map[string]float64{}
|
||||
for s, f := range setting.GetGroupRatioCopy() {
|
||||
for s, f := range ratio_setting.GetGroupRatioCopy() {
|
||||
groupRatio[s] = f
|
||||
}
|
||||
var group string
|
||||
@@ -22,7 +22,7 @@ func GetPricing(c *gin.Context) {
|
||||
if err == nil {
|
||||
group = user.Group
|
||||
for g := range groupRatio {
|
||||
ratio, ok := setting.GetGroupGroupRatio(group, g)
|
||||
ratio, ok := ratio_setting.GetGroupGroupRatio(group, g)
|
||||
if ok {
|
||||
groupRatio[g] = ratio
|
||||
}
|
||||
@@ -32,7 +32,7 @@ func GetPricing(c *gin.Context) {
|
||||
|
||||
usableGroup = setting.GetUserUsableGroups(group)
|
||||
// check groupRatio contains usableGroup
|
||||
for group := range setting.GetGroupRatioCopy() {
|
||||
for group := range ratio_setting.GetGroupRatioCopy() {
|
||||
if _, ok := usableGroup[group]; !ok {
|
||||
delete(groupRatio, group)
|
||||
}
|
||||
@@ -47,7 +47,7 @@ func GetPricing(c *gin.Context) {
|
||||
}
|
||||
|
||||
func ResetModelRatio(c *gin.Context) {
|
||||
defaultStr := operation_setting.DefaultModelRatio2JSONString()
|
||||
defaultStr := ratio_setting.DefaultModelRatio2JSONString()
|
||||
err := model.UpdateOption("ModelRatio", defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
@@ -56,7 +56,7 @@ func ResetModelRatio(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
err = operation_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||
err = ratio_setting.UpdateModelRatioByJSONString(defaultStr)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
|
||||
24
controller/ratio_config.go
Normal file
24
controller/ratio_config.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetRatioConfig(c *gin.Context) {
|
||||
if !ratio_setting.IsExposeRatioEnabled() {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "倍率配置接口未启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": ratio_setting.GetExposedData(),
|
||||
})
|
||||
}
|
||||
474
controller/ratio_sync.go
Normal file
474
controller/ratio_sync.go
Normal file
@@ -0,0 +1,474 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Defaults applied by the upstream ratio fetch.
const (
	// defaultTimeoutSeconds is the per-upstream timeout used when the request
	// does not carry a positive one.
	defaultTimeoutSeconds = 10
	// defaultEndpoint is the path queried when an upstream has no endpoint.
	defaultEndpoint = "/api/ratio_config"
	// maxConcurrentFetches caps the number of upstream requests in flight.
	maxConcurrentFetches = 8
)

// ratioTypes lists the ratio categories that are compared between the local
// configuration and each upstream.
var ratioTypes = []string{
	"model_ratio",
	"completion_ratio",
	"cache_ratio",
	"model_price",
}

// upstreamResult is the outcome of fetching one upstream: either Data is set
// or Err holds the failure message.
type upstreamResult struct {
	Name string         `json:"name"`
	Data map[string]any `json:"data,omitempty"`
	Err  string         `json:"err,omitempty"`
}
|
||||
|
||||
func FetchUpstreamRatios(c *gin.Context) {
|
||||
var req dto.UpstreamRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Timeout <= 0 {
|
||||
req.Timeout = defaultTimeoutSeconds
|
||||
}
|
||||
|
||||
var upstreams []dto.UpstreamDTO
|
||||
|
||||
if len(req.Upstreams) > 0 {
|
||||
for _, u := range req.Upstreams {
|
||||
if strings.HasPrefix(u.BaseURL, "http") {
|
||||
if u.Endpoint == "" {
|
||||
u.Endpoint = defaultEndpoint
|
||||
}
|
||||
u.BaseURL = strings.TrimRight(u.BaseURL, "/")
|
||||
upstreams = append(upstreams, u)
|
||||
}
|
||||
}
|
||||
} else if len(req.ChannelIDs) > 0 {
|
||||
intIds := make([]int, 0, len(req.ChannelIDs))
|
||||
for _, id64 := range req.ChannelIDs {
|
||||
intIds = append(intIds, int(id64))
|
||||
}
|
||||
dbChannels, err := model.GetChannelsByIds(intIds)
|
||||
if err != nil {
|
||||
common.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
||||
return
|
||||
}
|
||||
for _, ch := range dbChannels {
|
||||
if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
||||
upstreams = append(upstreams, dto.UpstreamDTO{
|
||||
ID: ch.Id,
|
||||
Name: ch.Name,
|
||||
BaseURL: strings.TrimRight(base, "/"),
|
||||
Endpoint: "",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(upstreams) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
ch := make(chan upstreamResult, len(upstreams))
|
||||
|
||||
sem := make(chan struct{}, maxConcurrentFetches)
|
||||
|
||||
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
||||
|
||||
for _, chn := range upstreams {
|
||||
wg.Add(1)
|
||||
go func(chItem dto.UpstreamDTO) {
|
||||
defer wg.Done()
|
||||
|
||||
sem <- struct{}{}
|
||||
defer func() { <-sem }()
|
||||
|
||||
endpoint := chItem.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = defaultEndpoint
|
||||
} else if !strings.HasPrefix(endpoint, "/") {
|
||||
endpoint = "/" + endpoint
|
||||
}
|
||||
fullURL := chItem.BaseURL + endpoint
|
||||
|
||||
uniqueName := chItem.Name
|
||||
if chItem.ID != 0 {
|
||||
uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||
if err != nil {
|
||||
common.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
common.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
||||
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
||||
return
|
||||
}
|
||||
// 兼容两种上游接口格式:
|
||||
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
||||
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
||||
var body struct {
|
||||
Success bool `json:"success"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
common.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
if !body.Success {
|
||||
ch <- upstreamResult{Name: uniqueName, Err: body.Message}
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试按 type1 解析
|
||||
var type1Data map[string]any
|
||||
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
||||
// 如果包含至少一个 ratioTypes 字段,则认为是 type1
|
||||
isType1 := false
|
||||
for _, rt := range ratioTypes {
|
||||
if _, ok := type1Data[rt]; ok {
|
||||
isType1 = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if isType1 {
|
||||
ch <- upstreamResult{Name: uniqueName, Data: type1Data}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 如果不是 type1,则尝试按 type2 (/api/pricing) 解析
|
||||
var pricingItems []struct {
|
||||
ModelName string `json:"model_name"`
|
||||
QuotaType int `json:"quota_type"`
|
||||
ModelRatio float64 `json:"model_ratio"`
|
||||
ModelPrice float64 `json:"model_price"`
|
||||
CompletionRatio float64 `json:"completion_ratio"`
|
||||
}
|
||||
if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
|
||||
common.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
|
||||
return
|
||||
}
|
||||
|
||||
modelRatioMap := make(map[string]float64)
|
||||
completionRatioMap := make(map[string]float64)
|
||||
modelPriceMap := make(map[string]float64)
|
||||
|
||||
for _, item := range pricingItems {
|
||||
if item.QuotaType == 1 {
|
||||
modelPriceMap[item.ModelName] = item.ModelPrice
|
||||
} else {
|
||||
modelRatioMap[item.ModelName] = item.ModelRatio
|
||||
// completionRatio 可能为 0,此时也直接赋值,保持与上游一致
|
||||
completionRatioMap[item.ModelName] = item.CompletionRatio
|
||||
}
|
||||
}
|
||||
|
||||
converted := make(map[string]any)
|
||||
|
||||
if len(modelRatioMap) > 0 {
|
||||
ratioAny := make(map[string]any, len(modelRatioMap))
|
||||
for k, v := range modelRatioMap {
|
||||
ratioAny[k] = v
|
||||
}
|
||||
converted["model_ratio"] = ratioAny
|
||||
}
|
||||
|
||||
if len(completionRatioMap) > 0 {
|
||||
compAny := make(map[string]any, len(completionRatioMap))
|
||||
for k, v := range completionRatioMap {
|
||||
compAny[k] = v
|
||||
}
|
||||
converted["completion_ratio"] = compAny
|
||||
}
|
||||
|
||||
if len(modelPriceMap) > 0 {
|
||||
priceAny := make(map[string]any, len(modelPriceMap))
|
||||
for k, v := range modelPriceMap {
|
||||
priceAny[k] = v
|
||||
}
|
||||
converted["model_price"] = priceAny
|
||||
}
|
||||
|
||||
ch <- upstreamResult{Name: uniqueName, Data: converted}
|
||||
}(chn)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
|
||||
localData := ratio_setting.GetExposedData()
|
||||
|
||||
var testResults []dto.TestResult
|
||||
var successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}
|
||||
|
||||
for r := range ch {
|
||||
if r.Err != "" {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "error",
|
||||
Error: r.Err,
|
||||
})
|
||||
} else {
|
||||
testResults = append(testResults, dto.TestResult{
|
||||
Name: r.Name,
|
||||
Status: "success",
|
||||
})
|
||||
successfulChannels = append(successfulChannels, struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}{name: r.Name, data: r.Data})
|
||||
}
|
||||
}
|
||||
|
||||
differences := buildDifferences(localData, successfulChannels)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"differences": differences,
|
||||
"test_results": testResults,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
name string
|
||||
data map[string]any
|
||||
}) map[string]map[string]dto.DifferenceItem {
|
||||
differences := make(map[string]map[string]dto.DifferenceItem)
|
||||
|
||||
allModels := make(map[string]struct{})
|
||||
|
||||
for _, ratioType := range ratioTypes {
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
for modelName := range localRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
for modelName := range upstreamRatio {
|
||||
allModels[modelName] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
confidenceMap := make(map[string]map[string]bool)
|
||||
|
||||
// 预处理阶段:检查pricing接口的可信度
|
||||
for _, channel := range successfulChannels {
|
||||
confidenceMap[channel.name] = make(map[string]bool)
|
||||
|
||||
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
||||
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
||||
|
||||
if hasModelRatio && hasCompletionRatio {
|
||||
// 遍历所有模型,检查是否满足不可信条件
|
||||
for modelName := range allModels {
|
||||
// 默认为可信
|
||||
confidenceMap[channel.name][modelName] = true
|
||||
|
||||
// 检查是否满足不可信条件:model_ratio为37.5且completion_ratio为1
|
||||
if modelRatioVal, ok := modelRatios[modelName]; ok {
|
||||
if completionRatioVal, ok := completionRatios[modelName]; ok {
|
||||
// 转换为float64进行比较
|
||||
if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
||||
if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
||||
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
||||
confidenceMap[channel.name][modelName] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 如果不是从pricing接口获取的数据,则全部标记为可信
|
||||
for modelName := range allModels {
|
||||
confidenceMap[channel.name][modelName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName := range allModels {
|
||||
for _, ratioType := range ratioTypes {
|
||||
var localValue interface{} = nil
|
||||
if localRatioAny, ok := localData[ratioType]; ok {
|
||||
if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
||||
if val, exists := localRatio[modelName]; exists {
|
||||
localValue = val
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
upstreamValues := make(map[string]interface{})
|
||||
confidenceValues := make(map[string]bool)
|
||||
hasUpstreamValue := false
|
||||
hasDifference := false
|
||||
|
||||
for _, channel := range successfulChannels {
|
||||
var upstreamValue interface{} = nil
|
||||
|
||||
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
||||
if val, exists := upstreamRatio[modelName]; exists {
|
||||
upstreamValue = val
|
||||
hasUpstreamValue = true
|
||||
|
||||
if localValue != nil && localValue != val {
|
||||
hasDifference = true
|
||||
} else if localValue == val {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
}
|
||||
}
|
||||
if upstreamValue == nil && localValue == nil {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
|
||||
if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
||||
hasDifference = true
|
||||
}
|
||||
|
||||
upstreamValues[channel.name] = upstreamValue
|
||||
|
||||
confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
|
||||
}
|
||||
|
||||
shouldInclude := false
|
||||
|
||||
if localValue != nil {
|
||||
if hasDifference {
|
||||
shouldInclude = true
|
||||
}
|
||||
} else {
|
||||
if hasUpstreamValue {
|
||||
shouldInclude = true
|
||||
}
|
||||
}
|
||||
|
||||
if shouldInclude {
|
||||
if differences[modelName] == nil {
|
||||
differences[modelName] = make(map[string]dto.DifferenceItem)
|
||||
}
|
||||
differences[modelName][ratioType] = dto.DifferenceItem{
|
||||
Current: localValue,
|
||||
Upstreams: upstreamValues,
|
||||
Confidence: confidenceValues,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
channelHasDiff := make(map[string]bool)
|
||||
for _, ratioMap := range differences {
|
||||
for _, item := range ratioMap {
|
||||
for chName, val := range item.Upstreams {
|
||||
if val != nil && val != "same" {
|
||||
channelHasDiff[chName] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelName, ratioMap := range differences {
|
||||
for ratioType, item := range ratioMap {
|
||||
for chName := range item.Upstreams {
|
||||
if !channelHasDiff[chName] {
|
||||
delete(item.Upstreams, chName)
|
||||
delete(item.Confidence, chName)
|
||||
}
|
||||
}
|
||||
|
||||
allSame := true
|
||||
for _, v := range item.Upstreams {
|
||||
if v != "same" {
|
||||
allSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(item.Upstreams) == 0 || allSame {
|
||||
delete(ratioMap, ratioType)
|
||||
} else {
|
||||
differences[modelName][ratioType] = item
|
||||
}
|
||||
}
|
||||
|
||||
if len(ratioMap) == 0 {
|
||||
delete(differences, modelName)
|
||||
}
|
||||
}
|
||||
|
||||
return differences
|
||||
}
|
||||
|
||||
func GetSyncableChannels(c *gin.Context) {
|
||||
channels, err := model.GetAllChannels(0, 0, true, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var syncableChannels []dto.SyncableChannel
|
||||
for _, channel := range channels {
|
||||
if channel.GetBaseURL() != "" {
|
||||
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||
ID: channel.Id,
|
||||
Name: channel.Name,
|
||||
BaseURL: channel.GetBaseURL(),
|
||||
Status: channel.Status,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": syncableChannels,
|
||||
})
|
||||
}
|
||||
@@ -259,7 +259,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
|
||||
AutoBan: &autoBanInt,
|
||||
}, nil
|
||||
}
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
|
||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
||||
}
|
||||
@@ -388,7 +388,7 @@ func RelayTask(c *gin.Context) {
|
||||
retryTimes = 0
|
||||
}
|
||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
||||
channel, _, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, i)
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||
break
|
||||
@@ -420,7 +420,7 @@ func RelayTask(c *gin.Context) {
|
||||
func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
|
||||
var err *dto.TaskError
|
||||
switch relayMode {
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
|
||||
case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeKlingFetchByID:
|
||||
err = relay.RelayTaskFetch(c, relayMode)
|
||||
default:
|
||||
err = relay.RelayTaskSubmit(c, relayMode)
|
||||
|
||||
@@ -74,6 +74,8 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
||||
//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
|
||||
case constant.TaskPlatformSuno:
|
||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
case constant.TaskPlatformKling, constant.TaskPlatformJimeng:
|
||||
_ = UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM)
|
||||
default:
|
||||
common.SysLog("未知平台")
|
||||
}
|
||||
|
||||
138
controller/task_video.go
Normal file
138
controller/task_video.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"one-api/relay/channel"
|
||||
"time"
|
||||
)
|
||||
|
||||
func UpdateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
if err := updateVideoTaskAll(ctx, platform, channelId, taskIds, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoTaskAll(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
cacheGetChannel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
errUpdate := model.TaskBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if errUpdate != nil {
|
||||
common.SysError(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate))
|
||||
}
|
||||
return fmt.Errorf("CacheGetChannel failed: %w", err)
|
||||
}
|
||||
adaptor := relay.GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("video adaptor not found")
|
||||
}
|
||||
for _, taskId := range taskIds {
|
||||
if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, channel *model.Channel, taskId string, taskM map[string]*model.Task) error {
|
||||
baseURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseURL = channel.GetBaseURL()
|
||||
}
|
||||
|
||||
task := taskM[taskId]
|
||||
if task == nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId))
|
||||
return fmt.Errorf("task %s not found", taskId)
|
||||
}
|
||||
resp, err := adaptor.FetchTask(baseURL, channel.Key, map[string]any{
|
||||
"task_id": taskId,
|
||||
"action": task.Action,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err)
|
||||
}
|
||||
//if resp.StatusCode != http.StatusOK {
|
||||
//return fmt.Errorf("get Video Task status code: %d", resp.StatusCode)
|
||||
//}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("readAll failed for task %s: %w", taskId, err)
|
||||
}
|
||||
|
||||
taskResult, err := adaptor.ParseTaskResult(responseBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
}
|
||||
//if taskResult.Code != 0 {
|
||||
// return fmt.Errorf("video task fetch failed for task %s", taskId)
|
||||
//}
|
||||
|
||||
now := time.Now().Unix()
|
||||
if taskResult.Status == "" {
|
||||
return fmt.Errorf("task %s status is empty", taskId)
|
||||
}
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
case model.TaskStatusSubmitted:
|
||||
task.Progress = "10%"
|
||||
case model.TaskStatusQueued:
|
||||
task.Progress = "20%"
|
||||
case model.TaskStatusInProgress:
|
||||
task.Progress = "30%"
|
||||
if task.StartTime == 0 {
|
||||
task.StartTime = now
|
||||
}
|
||||
case model.TaskStatusSuccess:
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Url
|
||||
case model.TaskStatusFailure:
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
common.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
common.LogError(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, common.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, taskId)
|
||||
}
|
||||
if taskResult.Progress != "" {
|
||||
task.Progress = taskResult.Progress
|
||||
}
|
||||
|
||||
task.Data = responseBody
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysError("UpdateVideoTask task error: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -258,3 +258,32 @@ func UpdateToken(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type TokenBatch struct {
|
||||
Ids []int `json:"ids"`
|
||||
}
|
||||
|
||||
func DeleteTokenBatch(c *gin.Context) {
|
||||
tokenBatch := TokenBatch{}
|
||||
if err := c.ShouldBindJSON(&tokenBatch); err != nil || len(tokenBatch.Ids) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "参数错误",
|
||||
})
|
||||
return
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
count, err := model.BatchDeleteTokens(tokenBatch.Ids, userId)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": count,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -97,14 +97,12 @@ func RequestEpay(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"})
|
||||
return
|
||||
}
|
||||
payType := "wxpay"
|
||||
if req.PaymentMethod == "zfb" {
|
||||
payType = "alipay"
|
||||
}
|
||||
if req.PaymentMethod == "wx" {
|
||||
req.PaymentMethod = "wxpay"
|
||||
payType = "wxpay"
|
||||
|
||||
if !setting.ContainsPayMethod(req.PaymentMethod) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
@@ -116,7 +114,7 @@ func RequestEpay(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||
Type: payType,
|
||||
Type: req.PaymentMethod,
|
||||
ServiceTradeNo: tradeNo,
|
||||
Name: fmt.Sprintf("TUC%d", req.Amount),
|
||||
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||
|
||||
@@ -226,6 +226,9 @@ func Register(c *gin.Context) {
|
||||
UnlimitedQuota: true,
|
||||
ModelLimitsEnabled: false,
|
||||
}
|
||||
if setting.DefaultUseAutoGroup {
|
||||
token.Group = "auto"
|
||||
}
|
||||
if err := token.Insert(); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
|
||||
@@ -15,6 +15,7 @@ type ImageRequest struct {
|
||||
Background string `json:"background,omitempty"`
|
||||
Moderation string `json:"moderation,omitempty"`
|
||||
OutputFormat string `json:"output_format,omitempty"`
|
||||
Watermark *bool `json:"watermark,omitempty"`
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
|
||||
@@ -53,9 +53,11 @@ type GeneralOpenAIRequest struct {
|
||||
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||
Audio json.RawMessage `json:"audio,omitempty"`
|
||||
EnableThinking any `json:"enable_thinking,omitempty"` // ali
|
||||
THINKING json.RawMessage `json:"thinking,omitempty"` // doubao
|
||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||
WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"`
|
||||
// OpenRouter Params
|
||||
Usage json.RawMessage `json:"usage,omitempty"`
|
||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||
// Ali Qwen Params
|
||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||
@@ -644,4 +646,6 @@ type ResponsesToolsCall struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Parameters json.RawMessage `json:"parameters,omitempty"`
|
||||
Function json.RawMessage `json:"function,omitempty"`
|
||||
Container json.RawMessage `json:"container,omitempty"`
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ type OpenAITextResponse struct {
|
||||
Id string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Created any `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Error *OpenAIError `json:"error,omitempty"`
|
||||
Usage `json:"usage"`
|
||||
@@ -178,6 +178,8 @@ type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokensDetails *InputTokenDetails `json:"input_tokens_details"`
|
||||
// OpenRouter Params
|
||||
Cost float64 `json:"cost,omitempty"`
|
||||
}
|
||||
|
||||
type InputTokenDetails struct {
|
||||
|
||||
38
dto/ratio_sync.go
Normal file
38
dto/ratio_sync.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package dto
|
||||
|
||||
type (
	// UpstreamDTO describes one upstream to fetch ratio data from.
	UpstreamDTO struct {
		ID       int    `json:"id,omitempty"`
		Name     string `json:"name" binding:"required"`
		BaseURL  string `json:"base_url" binding:"required"`
		Endpoint string `json:"endpoint"`
	}

	// UpstreamRequest is the request body of the upstream ratio fetch.
	UpstreamRequest struct {
		ChannelIDs []int64       `json:"channel_ids"`
		Upstreams  []UpstreamDTO `json:"upstreams"`
		Timeout    int           `json:"timeout"`
	}

	// TestResult is the connectivity test result of one upstream.
	TestResult struct {
		Name   string `json:"name"`
		Status string `json:"status"`
		Error  string `json:"error,omitempty"`
	}

	// DifferenceItem is a single difference entry.
	// Current is the local value and may be nil.
	// Upstreams holds the upstream value of each channel: a concrete value,
	// the string "same", or nil.
	DifferenceItem struct {
		Current    any             `json:"current"`
		Upstreams  map[string]any  `json:"upstreams"`
		Confidence map[string]bool `json:"confidence"`
	}

	// SyncableChannel is the reduced channel view with ID, name, base URL and
	// status.
	SyncableChannel struct {
		ID      int    `json:"id"`
		Name    string `json:"name"`
		BaseURL string `json:"base_url"`
		Status  int    `json:"status"`
	}
)
|
||||
47
dto/video.go
Normal file
47
dto/video.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package dto
|
||||
|
||||
type (
	// VideoRequest is the request body for submitting a video generation task.
	VideoRequest struct {
		Model          string         `json:"model,omitempty" example:"kling-v1"`                                                                                                                                    // Model/style ID
		Prompt         string         `json:"prompt,omitempty" example:"宇航员站起身走了"`                                                                                                                                   // Text prompt
		Image          string         `json:"image,omitempty" example:"https://h2.inkwai.com/bs2/upload-ylab-stunt/se/ai_portal_queue_mmu_image_upscale_aiweb/3214b798-e1b4-4b00-b7af-72b5b0417420_raw_image_0.jpg"` // Image input (URL/Base64)
		Duration       float64        `json:"duration" example:"5.0"`                                                                                                                                                // Video duration (seconds)
		Width          int            `json:"width" example:"512"`                                                                                                                                                   // Video width
		Height         int            `json:"height" example:"512"`                                                                                                                                                  // Video height
		Fps            int            `json:"fps,omitempty" example:"30"`                                                                                                                                            // Video frame rate
		Seed           int            `json:"seed,omitempty" example:"20231234"`                                                                                                                                     // Random seed
		N              int            `json:"n,omitempty" example:"1"`                                                                                                                                               // Number of videos to generate
		ResponseFormat string         `json:"response_format,omitempty" example:"url"`                                                                                                                               // Response format
		User           string         `json:"user,omitempty" example:"user-1234"`                                                                                                                                    // User identifier
		Metadata       map[string]any `json:"metadata,omitempty"`                                                                                                                                                    // Vendor-specific/custom params (e.g. negative_prompt, style, quality_level, etc.)
	}

	// VideoResponse is the response returned after a video generation task
	// has been submitted.
	VideoResponse struct {
		TaskId string `json:"task_id"`
		Status string `json:"status"`
	}

	// VideoTaskResponse is the response of a video task status query.
	VideoTaskResponse struct {
		TaskId   string             `json:"task_id" example:"abcd1234efgh"` // Task ID
		Status   string             `json:"status" example:"succeeded"`     // Task status
		Url      string             `json:"url,omitempty"`                  // Video resource URL (on success)
		Format   string             `json:"format,omitempty" example:"mp4"` // Video format
		Metadata *VideoTaskMetadata `json:"metadata,omitempty"`             // Result metadata
		Error    *VideoTaskError    `json:"error,omitempty"`                // Error information (on failure)
	}

	// VideoTaskMetadata holds the metadata of a generated video.
	VideoTaskMetadata struct {
		Duration float64 `json:"duration" example:"5.0"`  // Actual duration of the generated video
		Fps      int     `json:"fps" example:"30"`        // Actual frame rate
		Width    int     `json:"width" example:"512"`     // Actual width
		Height   int     `json:"height" example:"512"`    // Actual height
		Seed     int     `json:"seed" example:"20231234"` // Random seed that was used
	}

	// VideoTaskError is the error information of a video task.
	VideoTaskError struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
	}
)
|
||||
4
main.go
4
main.go
@@ -12,7 +12,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/router"
|
||||
"one-api/service"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
@@ -74,7 +74,7 @@ func main() {
|
||||
}
|
||||
|
||||
// Initialize model settings
|
||||
operation_setting.InitRatioSettings()
|
||||
ratio_setting.InitRatioSettings()
|
||||
// Initialize constants
|
||||
constant.InitEnv()
|
||||
// Initialize options
|
||||
|
||||
@@ -184,7 +184,7 @@ func TokenAuth() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
// gemini api 从query中获取key
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||
skKey := c.Query("key")
|
||||
if skKey != "" {
|
||||
c.Request.Header.Set("Authorization", "Bearer "+skKey)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -48,13 +49,15 @@ func Distribute() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
// check group in common.GroupRatio
|
||||
if !setting.ContainsGroupRatio(tokenGroup) {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
if !ratio_setting.ContainsGroupRatio(tokenGroup) {
|
||||
if tokenGroup != "auto" {
|
||||
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
|
||||
return
|
||||
}
|
||||
}
|
||||
userGroup = tokenGroup
|
||||
}
|
||||
c.Set("group", userGroup)
|
||||
c.Set(constant.ContextKeyUsingGroup, userGroup)
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
@@ -95,9 +98,14 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
|
||||
if shouldSelectChannel {
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
|
||||
var selectGroup string
|
||||
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
showGroup := userGroup
|
||||
if userGroup == "auto" {
|
||||
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
|
||||
}
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", showGroup, modelRequest.Model)
|
||||
// 如果错误,但是渠道不为空,说明是数据库一致性问题
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
@@ -162,7 +170,26 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
}
|
||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
var platform string
|
||||
var relayMode int
|
||||
if strings.HasPrefix(modelRequest.Model, "jimeng") {
|
||||
platform = string(constant.TaskPlatformJimeng)
|
||||
relayMode = relayconstant.Path2RelayJimeng(c.Request.Method, c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeJimengFetchByID {
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
} else {
|
||||
platform = string(constant.TaskPlatformKling)
|
||||
relayMode = relayconstant.Path2RelayKling(c.Request.Method, c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeKlingFetchByID {
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
}
|
||||
c.Set("platform", platform)
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
|
||||
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
||||
relayMode := relayconstant.RelayModeGemini
|
||||
modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
|
||||
|
||||
45
middleware/kling_adapter.go
Normal file
45
middleware/kling_adapter.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
func KlingRequestConvert() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
var originalReq map[string]interface{}
|
||||
if err := common.UnmarshalBodyReusable(c, &originalReq); err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
model, _ := originalReq["model"].(string)
|
||||
prompt, _ := originalReq["prompt"].(string)
|
||||
|
||||
unifiedReq := map[string]interface{}{
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"metadata": originalReq,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(unifiedReq)
|
||||
if err != nil {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Rewrite request body and path
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
|
||||
c.Request.URL.Path = "/v1/video/generations"
|
||||
if image := originalReq["image"]; image == "" {
|
||||
c.Set("action", "textGenerate")
|
||||
}
|
||||
|
||||
// We have to reset the request body for the next handlers
|
||||
c.Set(common.KeyRequestBody, jsonData)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -5,10 +5,13 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"one-api/common"
|
||||
"one-api/setting"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var group2model2channels map[string]map[string][]*Channel
|
||||
@@ -75,7 +78,43 @@ func SyncChannelCache(frequency int) {
|
||||
}
|
||||
}
|
||||
|
||||
func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) {
|
||||
var channel *Channel
|
||||
var err error
|
||||
selectGroup := group
|
||||
if group == "auto" {
|
||||
if len(setting.AutoGroups) == 0 {
|
||||
return nil, selectGroup, errors.New("auto groups is not enabled")
|
||||
}
|
||||
for _, autoGroup := range setting.AutoGroups {
|
||||
if common.DebugEnabled {
|
||||
println("autoGroup:", autoGroup)
|
||||
}
|
||||
channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry)
|
||||
if channel == nil {
|
||||
continue
|
||||
} else {
|
||||
c.Set("auto_group", autoGroup)
|
||||
selectGroup = autoGroup
|
||||
if common.DebugEnabled {
|
||||
println("selectGroup:", selectGroup)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
channel, err = getRandomSatisfiedChannel(group, model, retry)
|
||||
if err != nil {
|
||||
return nil, group, err
|
||||
}
|
||||
}
|
||||
if channel == nil {
|
||||
return nil, group, errors.New("channel not found")
|
||||
}
|
||||
return channel, selectGroup, nil
|
||||
}
|
||||
|
||||
func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
|
||||
if strings.HasPrefix(model, "gpt-4-gizmo") {
|
||||
model = "gpt-4-gizmo-*"
|
||||
}
|
||||
|
||||
@@ -597,3 +597,39 @@ func CountAllTags() (int64, error) {
|
||||
err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// Get channels of specified type with pagination
|
||||
func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
|
||||
return channels, err
|
||||
}
|
||||
|
||||
// Count channels of specific type
|
||||
func CountChannelsByType(channelType int) (int64, error) {
|
||||
var count int64
|
||||
err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Return map[type]count for all channels
|
||||
func CountChannelsGroupByType() (map[int64]int64, error) {
|
||||
type result struct {
|
||||
Type int64 `gorm:"column:type"`
|
||||
Count int64 `gorm:"column:count"`
|
||||
}
|
||||
var results []result
|
||||
err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts := make(map[int64]int64)
|
||||
for _, r := range results {
|
||||
counts[r.Type] = r.Count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
@@ -46,6 +46,15 @@ func initCol() {
|
||||
logGroupCol = commonGroupCol
|
||||
logKeyCol = commonKeyCol
|
||||
}
|
||||
} else {
|
||||
// LOG_SQL_DSN 为空时,日志数据库与主数据库相同
|
||||
if common.UsingPostgreSQL {
|
||||
logGroupCol = `"group"`
|
||||
logKeyCol = `"key"`
|
||||
} else {
|
||||
logGroupCol = commonGroupCol
|
||||
logKeyCol = commonKeyCol
|
||||
}
|
||||
}
|
||||
// log sql type and database type
|
||||
common.SysLog("Using Log SQL Type: " + common.LogSqlType)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"one-api/setting"
|
||||
"one-api/setting/config"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -76,6 +77,9 @@ func InitOptionMap() {
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
|
||||
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
|
||||
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
@@ -94,13 +98,13 @@ func InitOptionMap() {
|
||||
common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
|
||||
common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
|
||||
common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
|
||||
common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
|
||||
common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
|
||||
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
|
||||
common.OptionMap["GroupGroupRatio"] = setting.GroupGroupRatio2JSONString()
|
||||
common.OptionMap["ModelRatio"] = ratio_setting.ModelRatio2JSONString()
|
||||
common.OptionMap["ModelPrice"] = ratio_setting.ModelPrice2JSONString()
|
||||
common.OptionMap["CacheRatio"] = ratio_setting.CacheRatio2JSONString()
|
||||
common.OptionMap["GroupRatio"] = ratio_setting.GroupRatio2JSONString()
|
||||
common.OptionMap["GroupGroupRatio"] = ratio_setting.GroupGroupRatio2JSONString()
|
||||
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = operation_setting.CompletionRatio2JSONString()
|
||||
common.OptionMap["CompletionRatio"] = ratio_setting.CompletionRatio2JSONString()
|
||||
common.OptionMap["TopUpLink"] = common.TopUpLink
|
||||
//common.OptionMap["ChatLink"] = common.ChatLink
|
||||
//common.OptionMap["ChatLink2"] = common.ChatLink2
|
||||
@@ -123,6 +127,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
|
||||
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
|
||||
common.OptionMap["AutomaticDisableKeywords"] = operation_setting.AutomaticDisableKeywordsToString()
|
||||
common.OptionMap["ExposeRatioEnabled"] = strconv.FormatBool(ratio_setting.IsExposeRatioEnabled())
|
||||
|
||||
// 自动添加所有注册的模型配置
|
||||
modelConfigs := config.GlobalConfig.ExportAllConfigs()
|
||||
@@ -192,7 +197,7 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.ImageDownloadPermission = intValue
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" {
|
||||
if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" || key == "DefaultUseAutoGroup" {
|
||||
boolValue := value == "true"
|
||||
switch key {
|
||||
case "PasswordRegisterEnabled":
|
||||
@@ -261,6 +266,10 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.SMTPSSLEnabled = boolValue
|
||||
case "WorkerAllowHttpImageRequestEnabled":
|
||||
setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
||||
case "DefaultUseAutoGroup":
|
||||
setting.DefaultUseAutoGroup = boolValue
|
||||
case "ExposeRatioEnabled":
|
||||
ratio_setting.SetExposeRatioEnabled(boolValue)
|
||||
}
|
||||
}
|
||||
switch key {
|
||||
@@ -287,6 +296,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
setting.PayAddress = value
|
||||
case "Chats":
|
||||
err = setting.UpdateChatsByJsonString(value)
|
||||
case "AutoGroups":
|
||||
err = setting.UpdateAutoGroupsByJsonString(value)
|
||||
case "CustomCallbackAddress":
|
||||
setting.CustomCallbackAddress = value
|
||||
case "EpayId":
|
||||
@@ -352,19 +363,19 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "DataExportDefaultTime":
|
||||
common.DataExportDefaultTime = value
|
||||
case "ModelRatio":
|
||||
err = operation_setting.UpdateModelRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateModelRatioByJSONString(value)
|
||||
case "GroupRatio":
|
||||
err = setting.UpdateGroupRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateGroupRatioByJSONString(value)
|
||||
case "GroupGroupRatio":
|
||||
err = setting.UpdateGroupGroupRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateGroupGroupRatioByJSONString(value)
|
||||
case "UserUsableGroups":
|
||||
err = setting.UpdateUserUsableGroupsByJSONString(value)
|
||||
case "CompletionRatio":
|
||||
err = operation_setting.UpdateCompletionRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateCompletionRatioByJSONString(value)
|
||||
case "ModelPrice":
|
||||
err = operation_setting.UpdateModelPriceByJSONString(value)
|
||||
err = ratio_setting.UpdateModelPriceByJSONString(value)
|
||||
case "CacheRatio":
|
||||
err = operation_setting.UpdateCacheRatioByJSONString(value)
|
||||
err = ratio_setting.UpdateCacheRatioByJSONString(value)
|
||||
case "TopUpLink":
|
||||
common.TopUpLink = value
|
||||
//case "ChatLink":
|
||||
@@ -381,6 +392,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
operation_setting.AutomaticDisableKeywordsFromString(value)
|
||||
case "StreamCacheQueueLength":
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
case "PayMethods":
|
||||
err = setting.UpdatePayMethodsByJsonString(value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -65,14 +65,14 @@ func updatePricing() {
|
||||
ModelName: model,
|
||||
EnableGroup: groups,
|
||||
}
|
||||
modelPrice, findPrice := operation_setting.GetModelPrice(model, false)
|
||||
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
|
||||
if findPrice {
|
||||
pricing.ModelPrice = modelPrice
|
||||
pricing.QuotaType = 1
|
||||
} else {
|
||||
modelRatio, _ := operation_setting.GetModelRatio(model)
|
||||
modelRatio, _ := ratio_setting.GetModelRatio(model)
|
||||
pricing.ModelRatio = modelRatio
|
||||
pricing.CompletionRatio = operation_setting.GetCompletionRatio(model)
|
||||
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
|
||||
pricing.QuotaType = 0
|
||||
}
|
||||
pricingMap = append(pricingMap, pricing)
|
||||
|
||||
@@ -327,3 +327,37 @@ func CountUserTokens(userId int) (int64, error) {
|
||||
err := DB.Model(&Token{}).Where("user_id = ?", userId).Count(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// BatchDeleteTokens 删除指定用户的一组令牌,返回成功删除数量
|
||||
func BatchDeleteTokens(ids []int, userId int) (int, error) {
|
||||
if len(ids) == 0 {
|
||||
return 0, errors.New("ids 不能为空!")
|
||||
}
|
||||
|
||||
tx := DB.Begin()
|
||||
|
||||
var tokens []Token
|
||||
if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Find(&tokens).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := tx.Where("user_id = ? AND id IN (?)", userId, ids).Delete(&Token{}).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if common.RedisEnabled {
|
||||
gopool.Go(func() {
|
||||
for _, t := range tokens {
|
||||
_ = cacheDeleteToken(t.Key)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return len(tokens), nil
|
||||
}
|
||||
|
||||
@@ -2,11 +2,12 @@ package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -48,6 +49,22 @@ func addNewRecord(type_ int, id int, value int) {
|
||||
}
|
||||
|
||||
func batchUpdate() {
|
||||
// check if there's any data to update
|
||||
hasData := false
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
if len(batchUpdateStores[i]) > 0 {
|
||||
hasData = true
|
||||
batchUpdateLocks[i].Unlock()
|
||||
break
|
||||
}
|
||||
batchUpdateLocks[i].Unlock()
|
||||
}
|
||||
|
||||
if !hasData {
|
||||
return
|
||||
}
|
||||
|
||||
common.SysLog("batch update started")
|
||||
for i := 0; i < BatchUpdateTypeCount; i++ {
|
||||
batchUpdateLocks[i].Lock()
|
||||
|
||||
@@ -55,7 +55,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
}
|
||||
|
||||
func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
|
||||
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
|
||||
|
||||
if err != nil {
|
||||
@@ -66,10 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
promptTokens := 0
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||
promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
|
||||
}
|
||||
promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
||||
preConsumedTokens = promptTokens
|
||||
relayInfo.PromptTokens = promptTokens
|
||||
}
|
||||
@@ -89,13 +86,11 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
}
|
||||
}()
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
audioRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||
if adaptor == nil {
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||
@@ -44,4 +44,6 @@ type TaskAdaptor interface {
|
||||
|
||||
// FetchTask
|
||||
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
|
||||
|
||||
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/openrouter"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
@@ -122,6 +123,21 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
|
||||
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
|
||||
}
|
||||
|
||||
if textRequest.Reasoning != nil {
|
||||
var reasoning openrouter.RequestReasoning
|
||||
if err := common.DecodeJson(textRequest.Reasoning, &reasoning); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
budgetTokens := reasoning.MaxTokens
|
||||
if budgetTokens > 0 {
|
||||
claudeRequest.Thinking = &dto.Thinking{
|
||||
Type: "enabled",
|
||||
BudgetTokens: &budgetTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if textRequest.Stop != nil {
|
||||
// stop maybe string/array string, convert to array string
|
||||
switch textRequest.Stop.(type) {
|
||||
@@ -454,6 +470,7 @@ type ClaudeResponseInfo struct {
|
||||
Model string
|
||||
ResponseText strings.Builder
|
||||
Usage *dto.Usage
|
||||
Done bool
|
||||
}
|
||||
|
||||
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
|
||||
@@ -461,20 +478,32 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
claudeInfo.ResponseId = claudeResponse.Message.Id
|
||||
claudeInfo.Model = claudeResponse.Message.Model
|
||||
|
||||
// message_start, 获取usage
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
if claudeResponse.Delta.Text != nil {
|
||||
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
|
||||
}
|
||||
if claudeResponse.Delta.Thinking != "" {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
|
||||
}
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
// 最终的usage获取
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||
|
||||
// 判断是否完整
|
||||
claudeInfo.Done = true
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
} else {
|
||||
return false
|
||||
@@ -506,25 +535,15 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
}
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
|
||||
} else {
|
||||
if claudeResponse.Type == "message_start" {
|
||||
// message_start, 获取usage
|
||||
info.UpstreamModelName = claudeResponse.Message.Model
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
|
||||
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
|
||||
} else if claudeResponse.Type == "content_block_delta" {
|
||||
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
|
||||
} else if claudeResponse.Type == "message_delta" {
|
||||
if claudeResponse.Usage.InputTokens > 0 {
|
||||
// 不叠加,只取最新的
|
||||
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
|
||||
}
|
||||
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
|
||||
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
helper.ClaudeChunkData(c, claudeResponse, data)
|
||||
@@ -544,29 +563,25 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
|
||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
|
||||
if common.DebugEnabled {
|
||||
common.SysError("claude response usage is not complete, maybe upstream error")
|
||||
}
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
if info.RelayFormat == relaycommon.RelayFormatClaude {
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 说明流模式建立失败,可能为官方出错
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//usage.PromptTokens = info.PromptTokens
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
//
|
||||
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
}
|
||||
if claudeInfo.Usage.CompletionTokens == 0 {
|
||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
@@ -619,10 +634,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
}
|
||||
}
|
||||
if requestMode == RequestModeCompletion {
|
||||
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
|
||||
}
|
||||
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||
claudeInfo.Usage.PromptTokens = info.PromptTokens
|
||||
claudeInfo.Usage.CompletionTokens = completionTokens
|
||||
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
|
||||
|
||||
@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
@@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
||||
for _, choice := range response.Choices {
|
||||
responseText += choice.Message.StringContent()
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
response.Usage = *usage
|
||||
response.Id = helper.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
@@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
||||
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
return nil, usage
|
||||
|
||||
@@ -162,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
}
|
||||
})
|
||||
if usage.PromptTokens == 0 {
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
@@ -106,7 +106,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
|
||||
var currentEvent string
|
||||
var currentData string
|
||||
var usage dto.Usage
|
||||
var usage = &dto.Usage{}
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
@@ -114,7 +114,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
if line == "" {
|
||||
if currentEvent != "" && currentData != "" {
|
||||
// handle last event
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
|
||||
currentEvent = ""
|
||||
currentData = ""
|
||||
}
|
||||
@@ -134,7 +134,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
|
||||
// Last event
|
||||
if currentEvent != "" && currentData != "" {
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
||||
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
@@ -143,12 +143,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
||||
helper.Done(c)
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||
}
|
||||
|
||||
return nil, &usage
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
||||
|
||||
@@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
||||
return true
|
||||
})
|
||||
helper.Done(c)
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
common.SysError("close_response_body_failed: " + err.Error())
|
||||
}
|
||||
if usage.TotalTokens == 0 {
|
||||
usage.PromptTokens = info.PromptTokens
|
||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
usage.CompletionTokens += nodeToken
|
||||
return nil, usage
|
||||
|
||||
@@ -73,12 +73,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.OriginModelName, "-thinking-") {
|
||||
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||
parts := strings.Split(info.UpstreamModelName, "-thinking-")
|
||||
info.UpstreamModelName = parts[0]
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 旧的适配
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,6 +140,7 @@ type GeminiChatGenerationConfig struct {
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -35,23 +36,10 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
||||
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 检查是否有候选响应
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return nil, &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: "No candidates returned",
|
||||
Type: "server_error",
|
||||
Param: "",
|
||||
Code: 500,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
// 计算使用量(基于 UsageMetadata)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
||||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
|
||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
|
||||
@@ -88,6 +76,8 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
err := common.DecodeJsonStr(data, &geminiResponse)
|
||||
@@ -102,13 +92,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||
imageCount++
|
||||
}
|
||||
if part.Text != "" {
|
||||
responseText.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新使用量统计
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
@@ -121,7 +114,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
||||
}
|
||||
|
||||
// 直接发送 GeminiChatResponse 响应
|
||||
err = helper.ObjectData(c, geminiResponse)
|
||||
err = helper.StringData(c, data)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
@@ -135,8 +128,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
||||
}
|
||||
}
|
||||
|
||||
// 计算最终使用量
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
|
||||
if usage.CompletionTokens == 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 空补全,不需要使用量
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
}
|
||||
|
||||
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
||||
//helper.Done(c)
|
||||
|
||||
@@ -39,11 +39,100 @@ var geminiSupportedMimeTypes = map[string]bool{
|
||||
|
||||
// Gemini 允许的思考预算范围
|
||||
const (
|
||||
pro25MinBudget = 128
|
||||
pro25MaxBudget = 32768
|
||||
flash25MaxBudget = 24576
|
||||
pro25MinBudget = 128
|
||||
pro25MaxBudget = 32768
|
||||
flash25MaxBudget = 24576
|
||||
flash25LiteMinBudget = 512
|
||||
flash25LiteMaxBudget = 24576
|
||||
)
|
||||
|
||||
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
|
||||
func clampThinkingBudget(modelName string, budget int) int {
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
||||
is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
|
||||
|
||||
if is25FlashLite {
|
||||
if budget < flash25LiteMinBudget {
|
||||
return flash25LiteMinBudget
|
||||
}
|
||||
if budget > flash25LiteMaxBudget {
|
||||
return flash25LiteMaxBudget
|
||||
}
|
||||
} else if isNew25Pro {
|
||||
if budget < pro25MinBudget {
|
||||
return pro25MinBudget
|
||||
}
|
||||
if budget > pro25MaxBudget {
|
||||
return pro25MaxBudget
|
||||
}
|
||||
} else { // 其他模型
|
||||
if budget < 0 {
|
||||
return 0
|
||||
}
|
||||
if budget > flash25MaxBudget {
|
||||
return flash25MaxBudget
|
||||
}
|
||||
}
|
||||
return budget
|
||||
}
|
||||
|
||||
func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) {
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
modelName := info.UpstreamModelName
|
||||
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
|
||||
|
||||
if strings.Contains(modelName, "-thinking-") {
|
||||
parts := strings.SplitN(modelName, "-thinking-", 2)
|
||||
if len(parts) == 2 && parts[1] != "" {
|
||||
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
||||
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(clampedBudget),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "-thinking") {
|
||||
unsupportedModels := []string{
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
"gemini-2.5-pro-preview-03-25",
|
||||
}
|
||||
isUnsupported := false
|
||||
for _, unsupportedModel := range unsupportedModels {
|
||||
if strings.HasPrefix(modelName, unsupportedModel) {
|
||||
isUnsupported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isUnsupported {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
} else {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
|
||||
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
|
||||
clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
|
||||
geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "-nothinking") {
|
||||
if !isNew25Pro {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
|
||||
|
||||
@@ -64,100 +153,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
}
|
||||
}
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.OriginModelName, "-thinking-") {
|
||||
parts := strings.SplitN(info.OriginModelName, "-thinking-", 2)
|
||||
if len(parts) == 2 && parts[1] != "" {
|
||||
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
|
||||
// 从模型名称成功解析预算
|
||||
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
|
||||
|
||||
if isNew25Pro {
|
||||
// 新的2.5pro模型:ThinkingBudget范围为128-32768
|
||||
if budgetTokens < pro25MinBudget {
|
||||
budgetTokens = pro25MinBudget
|
||||
} else if budgetTokens > pro25MaxBudget {
|
||||
budgetTokens = pro25MaxBudget
|
||||
}
|
||||
} else {
|
||||
// 其他模型:ThinkingBudget范围为0-24576
|
||||
if budgetTokens < 0 {
|
||||
budgetTokens = 0
|
||||
} else if budgetTokens > flash25MaxBudget {
|
||||
budgetTokens = flash25MaxBudget
|
||||
}
|
||||
}
|
||||
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(budgetTokens),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
}
|
||||
// 如果解析失败,则不设置ThinkingConfig,静默处理
|
||||
}
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-thinking") { // 保留旧逻辑以兼容
|
||||
// 硬编码不支持 ThinkingBudget 的旧模型
|
||||
unsupportedModels := []string{
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
"gemini-2.5-pro-preview-03-25",
|
||||
}
|
||||
|
||||
isUnsupported := false
|
||||
for _, unsupportedModel := range unsupportedModels {
|
||||
if strings.HasPrefix(info.OriginModelName, unsupportedModel) {
|
||||
isUnsupported = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isUnsupported {
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
} else {
|
||||
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
|
||||
|
||||
// 检查是否为新的2.5pro模型(支持ThinkingBudget但有特殊范围)
|
||||
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
|
||||
|
||||
if isNew25Pro {
|
||||
// 新的2.5pro模型:ThinkingBudget范围为128-32768
|
||||
if budgetTokens == 0 || budgetTokens < 128 {
|
||||
budgetTokens = 128
|
||||
} else if budgetTokens > 32768 {
|
||||
budgetTokens = 32768
|
||||
}
|
||||
} else {
|
||||
// 其他模型:ThinkingBudget范围为0-24576
|
||||
if budgetTokens == 0 || budgetTokens > 24576 {
|
||||
budgetTokens = 24576
|
||||
}
|
||||
}
|
||||
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(int(budgetTokens)),
|
||||
IncludeThoughts: true,
|
||||
}
|
||||
}
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
// 检查是否为新的2.5pro模型(不支持-nothinking,因为最低值只能为128)
|
||||
isNew25Pro := strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-05-06") &&
|
||||
!strings.HasPrefix(info.OriginModelName, "gemini-2.5-pro-preview-03-25")
|
||||
|
||||
if !isNew25Pro {
|
||||
// 只有非新2.5pro模型才支持-nothinking
|
||||
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
|
||||
ThinkingBudget: common.GetPointer(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ThinkingAdaptor(&geminiRequest, info)
|
||||
|
||||
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
|
||||
for _, category := range SafetySettingList {
|
||||
@@ -324,7 +320,8 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
|
||||
// 校验 MimeType 是否在 Gemini 支持的白名单中
|
||||
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
|
||||
return nil, fmt.Errorf("MIME type '%s' from URL '%s' is not supported by Gemini. Supported types are: %v", fileData.MimeType, part.GetImageMedia().Url, getSupportedMimeTypesList())
|
||||
url := part.GetImageMedia().Url
|
||||
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiPart{
|
||||
@@ -382,7 +379,9 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
|
||||
if content.Role == "assistant" {
|
||||
content.Role = "model"
|
||||
}
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
if len(content.Parts) > 0 {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(system_content) > 0 {
|
||||
|
||||
@@ -159,6 +159,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
|
||||
request.StreamOptions = nil
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenRouter {
|
||||
if len(request.Usage) == 0 {
|
||||
request.Usage = json.RawMessage(`{"include":true}`)
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(request.Model, "o") {
|
||||
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
||||
request.MaxCompletionTokens = request.MaxTokens
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"math"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
@@ -16,6 +15,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
@@ -181,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
} else {
|
||||
if info.ChannelType == common.ChannelTypeDeepSeek {
|
||||
@@ -216,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
forceFormat := false
|
||||
if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
|
||||
forceFormat = forceFmt
|
||||
@@ -225,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||
completionTokens := 0
|
||||
for _, choice := range simpleResponse.Choices {
|
||||
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||
completionTokens += ctkm
|
||||
}
|
||||
simpleResponse.Usage = dto.Usage{
|
||||
@@ -276,9 +276,9 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
// the status code has been judged before, if there is a body reading failure,
|
||||
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
|
||||
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
||||
// if the upstream returns a specific status code, once the upstream has already written the header,
|
||||
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
||||
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
|
||||
// if the upstream returns a specific status code, once the upstream has already written the header,
|
||||
// the subsequent failure of the response body should be regarded as a non-recoverable error,
|
||||
// and can be terminated directly.
|
||||
defer resp.Body.Close()
|
||||
usage := &dto.Usage{}
|
||||
@@ -346,12 +346,12 @@ func countAudioTokens(c *gin.Context) (int, error) {
|
||||
if err = c.ShouldBind(&reqBody); err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
|
||||
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
|
||||
reqFp, err := reqBody.File.Open()
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
defer reqFp.Close()
|
||||
defer reqFp.Close()
|
||||
|
||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||
if err != nil {
|
||||
@@ -623,13 +623,13 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
common.SysError("error copying response body: " + err.Error())
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
// Once we've written to the client, we should not return errors anymore
|
||||
// because the upstream has already consumed resources and returned content
|
||||
// We should still perform billing even if parsing fails
|
||||
var usageResp dto.SimpleResponse
|
||||
err = json.Unmarshal(responseBody, &usageResp)
|
||||
if err != nil {
|
||||
|
||||
@@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
|
||||
tempStr := responseTextBuilder.String()
|
||||
if len(tempStr) > 0 {
|
||||
// 非正常结束,使用输出文本的 token 数量
|
||||
completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
|
||||
completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
|
||||
usage.CompletionTokens = completionTokens
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = palmStreamHandler(c, resp)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||
}
|
||||
|
||||
@@ -155,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||
completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
||||
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
||||
usage := dto.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
|
||||
379
relay/channel/task/jimeng/adaptor.go
Normal file
379
relay/channel/task/jimeng/adaptor.go
Normal file
@@ -0,0 +1,379 @@
|
||||
package jimeng
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/model"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
// requestPayload is the JSON body sent upstream when submitting a task.
// Image input is supplied through either BinaryDataBase64 or ImageUrls.
type requestPayload struct {
	// ReqKey identifies the upstream capability being requested.
	ReqKey string `json:"req_key"`
	// BinaryDataBase64 carries inline image data; omitted when empty.
	BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
	// ImageUrls carries image URLs; omitted when empty.
	ImageUrls []string `json:"image_urls,omitempty"`
	// Prompt is the generation prompt; omitted when empty.
	Prompt string `json:"prompt,omitempty"`
	// Seed is always serialised, including when zero.
	Seed int64 `json:"seed"`
	// AspectRatio is always serialised, e.g. "16:9".
	AspectRatio string `json:"aspect_ratio"`
}
|
||||
|
||||
// responsePayload is the JSON envelope returned by the upstream submit call.
type responsePayload struct {
	// Code is the upstream status code; DoResponse treats 10000 as success.
	Code int `json:"code"`
	// Message is the upstream human-readable status text.
	Message string `json:"message"`
	// RequestId is the upstream request identifier.
	RequestId string `json:"request_id"`
	// Data holds the identifier of the created task.
	Data struct {
		TaskID string `json:"task_id"`
	} `json:"data"`
}
|
||||
|
||||
// responseTask is the JSON envelope returned by the upstream task-result call.
type responseTask struct {
	// Code is the upstream status code; ParseTaskResult treats 10000 as success.
	Code int `json:"code"`
	// Data carries the task state and any produced artefacts.
	Data struct {
		// BinaryDataBase64 and ImageUrls are decoded loosely because this
		// package never inspects them.
		BinaryDataBase64 []any  `json:"binary_data_base64"`
		ImageUrls        any    `json:"image_urls"`
		RespData         string `json:"resp_data"`
		// Status is the upstream task state, e.g. "in_queue" or "done".
		Status string `json:"status"`
		// VideoUrl is the location of the generated video.
		VideoUrl string `json:"video_url"`
	} `json:"data"`
	// Message is the upstream human-readable status text.
	Message string `json:"message"`
	// RequestId is the upstream request identifier.
	RequestId string `json:"request_id"`
	// Status is a top-level numeric status, distinct from Data.Status.
	Status int `json:"status"`
	// TimeElapsed is kept as the raw string sent by the upstream.
	TimeElapsed string `json:"time_elapsed"`
}
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
|
||||
// TaskAdaptor implements the task-relay adaptor for the "jimeng" channel.
// Its unexported fields are populated by Init.
type TaskAdaptor struct {
	// ChannelType is copied from the relay info.
	ChannelType int
	// accessKey and secretKey are the two halves of the channel API key.
	accessKey string
	secretKey string
	// baseURL is the upstream endpoint root used to build request URLs.
	baseURL string
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.BaseUrl
|
||||
|
||||
// apiKey format: "access_key,secret_key"
|
||||
keyParts := strings.Split(info.ApiKey, ",")
|
||||
if len(keyParts) == 2 {
|
||||
a.accessKey = strings.TrimSpace(keyParts[0])
|
||||
a.secretKey = strings.TrimSpace(keyParts[1])
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := "generate"
|
||||
info.Action = action
|
||||
|
||||
req := relaycommon.TaskSubmitReq{}
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store into context for later usage
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
return a.signRequest(req, a.accessKey, a.secretKey)
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Jimeng specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("task_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "convert request payload failed")
|
||||
}
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
// Parse Jimeng response
|
||||
var jResp responsePayload
|
||||
if err := json.Unmarshal(responseBody, &jResp); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if jResp.Code != 10000 {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": jResp.Data.TaskID})
|
||||
return jResp.Data.TaskID, responseBody, nil
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
|
||||
payload := map[string]string{
|
||||
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
|
||||
"task_id": taskID,
|
||||
}
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshal fetch task payload failed")
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
keyParts := strings.Split(key, ",")
|
||||
if len(keyParts) != 2 {
|
||||
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak,sk'")
|
||||
}
|
||||
accessKey := strings.TrimSpace(keyParts[0])
|
||||
secretKey := strings.TrimSpace(keyParts[1])
|
||||
|
||||
if err := a.signRequest(req, accessKey, secretKey); err != nil {
|
||||
return nil, errors.Wrap(err, "sign request failed")
|
||||
}
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
return []string{"jimeng_vgfm_t2v_l20"}
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetChannelName() string {
|
||||
return "jimeng"
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error {
|
||||
var bodyBytes []byte
|
||||
var err error
|
||||
|
||||
if req.Body != nil {
|
||||
bodyBytes, err = io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "read request body failed")
|
||||
}
|
||||
_ = req.Body.Close()
|
||||
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
|
||||
} else {
|
||||
bodyBytes = []byte{}
|
||||
}
|
||||
|
||||
payloadHash := sha256.Sum256(bodyBytes)
|
||||
hexPayloadHash := hex.EncodeToString(payloadHash[:])
|
||||
|
||||
t := time.Now().UTC()
|
||||
xDate := t.Format("20060102T150405Z")
|
||||
shortDate := t.Format("20060102")
|
||||
|
||||
req.Header.Set("Host", req.URL.Host)
|
||||
req.Header.Set("X-Date", xDate)
|
||||
req.Header.Set("X-Content-Sha256", hexPayloadHash)
|
||||
|
||||
// Sort and encode query parameters to create canonical query string
|
||||
queryParams := req.URL.Query()
|
||||
sortedKeys := make([]string, 0, len(queryParams))
|
||||
for k := range queryParams {
|
||||
sortedKeys = append(sortedKeys, k)
|
||||
}
|
||||
sort.Strings(sortedKeys)
|
||||
var queryParts []string
|
||||
for _, k := range sortedKeys {
|
||||
values := queryParams[k]
|
||||
sort.Strings(values)
|
||||
for _, v := range values {
|
||||
queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
|
||||
}
|
||||
}
|
||||
canonicalQueryString := strings.Join(queryParts, "&")
|
||||
|
||||
headersToSign := map[string]string{
|
||||
"host": req.URL.Host,
|
||||
"x-date": xDate,
|
||||
"x-content-sha256": hexPayloadHash,
|
||||
}
|
||||
if req.Header.Get("Content-Type") != "" {
|
||||
headersToSign["content-type"] = req.Header.Get("Content-Type")
|
||||
}
|
||||
|
||||
var signedHeaderKeys []string
|
||||
for k := range headersToSign {
|
||||
signedHeaderKeys = append(signedHeaderKeys, k)
|
||||
}
|
||||
sort.Strings(signedHeaderKeys)
|
||||
|
||||
var canonicalHeaders strings.Builder
|
||||
for _, k := range signedHeaderKeys {
|
||||
canonicalHeaders.WriteString(k)
|
||||
canonicalHeaders.WriteString(":")
|
||||
canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
|
||||
canonicalHeaders.WriteString("\n")
|
||||
}
|
||||
signedHeaders := strings.Join(signedHeaderKeys, ";")
|
||||
|
||||
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
|
||||
req.Method,
|
||||
req.URL.Path,
|
||||
canonicalQueryString,
|
||||
canonicalHeaders.String(),
|
||||
signedHeaders,
|
||||
hexPayloadHash,
|
||||
)
|
||||
|
||||
hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
|
||||
hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
|
||||
|
||||
region := "cn-north-1"
|
||||
serviceName := "cv"
|
||||
credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
|
||||
stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
|
||||
xDate,
|
||||
credentialScope,
|
||||
hexHashedCanonicalRequest,
|
||||
)
|
||||
|
||||
kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
|
||||
kRegion := hmacSHA256(kDate, []byte(region))
|
||||
kService := hmacSHA256(kRegion, []byte(serviceName))
|
||||
kSigning := hmacSHA256(kService, []byte("request"))
|
||||
signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
|
||||
|
||||
authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
|
||||
accessKey,
|
||||
credentialScope,
|
||||
signedHeaders,
|
||||
signature,
|
||||
)
|
||||
req.Header.Set("Authorization", authorization)
|
||||
return nil
|
||||
}
|
||||
|
||||
// hmacSHA256 returns the 32-byte HMAC-SHA256 of data keyed with key.
func hmacSHA256(key []byte, data []byte) []byte {
	mac := hmac.New(sha256.New, key)
	mac.Write(data)
	digest := mac.Sum(nil)
	return digest
}
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
ReqKey: "jimeng_vgfm_i2v_l20",
|
||||
Prompt: req.Prompt,
|
||||
AspectRatio: "16:9", // Default aspect ratio
|
||||
Seed: -1, // Default to random
|
||||
}
|
||||
|
||||
// Handle one-of image_urls or binary_data_base64
|
||||
if req.Image != "" {
|
||||
if strings.HasPrefix(req.Image, "http") {
|
||||
r.ImageUrls = []string{req.Image}
|
||||
} else {
|
||||
r.BinaryDataBase64 = []string{req.Image}
|
||||
}
|
||||
}
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
resTask := responseTask{}
|
||||
if err := json.Unmarshal(respBody, &resTask); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal task result failed")
|
||||
}
|
||||
taskResult := relaycommon.TaskInfo{}
|
||||
if resTask.Code == 10000 {
|
||||
taskResult.Code = 0
|
||||
} else {
|
||||
taskResult.Code = resTask.Code // todo uni code
|
||||
taskResult.Reason = resTask.Message
|
||||
taskResult.Status = model.TaskStatusFailure
|
||||
taskResult.Progress = "100%"
|
||||
}
|
||||
switch resTask.Data.Status {
|
||||
case "in_queue":
|
||||
taskResult.Status = model.TaskStatusQueued
|
||||
taskResult.Progress = "10%"
|
||||
case "done":
|
||||
taskResult.Status = model.TaskStatusSuccess
|
||||
taskResult.Progress = "100%"
|
||||
}
|
||||
taskResult.Url = resTask.Data.VideoUrl
|
||||
return &taskResult, nil
|
||||
}
|
||||
345
relay/channel/task/kling/adaptor.go
Normal file
345
relay/channel/task/kling/adaptor.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package kling
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/samber/lo"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
// SubmitReq is the client-facing request body for submitting a video
// generation task.
type SubmitReq struct {
	// Prompt is the text prompt; ValidateRequestAndSetAction rejects blank prompts.
	Prompt string `json:"prompt"`
	Model  string `json:"model,omitempty"`
	Mode   string `json:"mode,omitempty"`
	Image  string `json:"image,omitempty"`
	// Size is a "WIDTHxHEIGHT" string that getAspectRatio maps to an aspect ratio.
	Size     string `json:"size,omitempty"`
	Duration int    `json:"duration,omitempty"`
	// Metadata is overlaid onto the upstream payload by convertToRequestPayload.
	Metadata map[string]any `json:"metadata,omitempty"`
}
|
||||
|
||||
// requestPayload is the JSON body produced by convertToRequestPayload and
// marshalled by BuildRequestBody for the upstream request. All fields are
// omitted from the JSON when they hold their zero value.
type requestPayload struct {
	Prompt string `json:"prompt,omitempty"`
	Image  string `json:"image,omitempty"`
	Mode   string `json:"mode,omitempty"`
	// Duration is sent as a decimal string, not a number.
	Duration    string  `json:"duration,omitempty"`
	AspectRatio string  `json:"aspect_ratio,omitempty"`
	ModelName   string  `json:"model_name,omitempty"`
	CfgScale    float64 `json:"cfg_scale,omitempty"`
}
|
||||
|
||||
// responsePayload is the JSON envelope decoded from upstream responses, both
// when a task is submitted (DoResponse) and when it is polled
// (ParseTaskResult).
type responsePayload struct {
	// Code is treated as success by DoResponse when it is 0.
	Code      int    `json:"code"`
	Message   string `json:"message"`
	RequestId string `json:"request_id"`
	Data      struct {
		TaskId string `json:"task_id"`
		// TaskStatus is one of the strings handled in ParseTaskResult:
		// "submitted", "processing", "succeed" or "failed".
		TaskStatus    string `json:"task_status"`
		TaskStatusMsg string `json:"task_status_msg"`
		TaskResult    struct {
			// Videos lists generated videos; ParseTaskResult uses the first entry.
			Videos []struct {
				Id       string `json:"id"`
				Url      string `json:"url"`
				Duration string `json:"duration"`
			} `json:"videos"`
		} `json:"task_result"`
		CreatedAt int64 `json:"created_at"`
		UpdatedAt int64 `json:"updated_at"`
	} `json:"data"`
}
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
|
||||
// TaskAdaptor is the Kling video task adaptor. Its fields are filled in by
// Init from the relay info of the current request.
type TaskAdaptor struct {
	ChannelType int
	// accessKey and secretKey are the two halves of the channel API key
	// ("access_key|secret_key"); they stay empty when the key is malformed.
	accessKey string
	secretKey string
	// baseURL is prepended to the upstream endpoint path.
	baseURL string
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.BaseUrl
|
||||
|
||||
// apiKey format: "access_key|secret_key"
|
||||
keyParts := strings.Split(info.ApiKey, "|")
|
||||
if len(keyParts) == 2 {
|
||||
a.accessKey = strings.TrimSpace(keyParts[0])
|
||||
a.secretKey = strings.TrimSpace(keyParts[1])
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := "generate"
|
||||
info.Action = action
|
||||
|
||||
var req SubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store into context for later usage
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
|
||||
path := lo.Ternary(info.Action == "generate", "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
return fmt.Sprintf("%s%s", a.baseURL, path), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
|
||||
token, err := a.createJWTToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create JWT token: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Kling specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
|
||||
v, exists := c.Get("task_request")
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(SubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
if action := c.GetString("action"); action != "" {
|
||||
info.Action = action
|
||||
}
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt Kling response parse first.
|
||||
var kResp responsePayload
|
||||
if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskId})
|
||||
return kResp.Data.TaskId, responseBody, nil
|
||||
}
|
||||
|
||||
// Fallback generic task response.
|
||||
var generic dto.TaskResponse[string]
|
||||
if err := json.Unmarshal(responseBody, &generic); err != nil {
|
||||
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !generic.IsSuccess() {
|
||||
taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
|
||||
return generic.Data, responseBody, nil
|
||||
}
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
action, ok := body["action"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid action")
|
||||
}
|
||||
path := lo.Ternary(action == "generate", "/v1/videos/image2video", "/v1/videos/text2video")
|
||||
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token, err := a.createJWTTokenWithKey(key)
|
||||
if err != nil {
|
||||
token = key
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", "kling-sdk/1.0")
|
||||
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string {
|
||||
return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetChannelName() string {
|
||||
return "kling"
|
||||
}
|
||||
|
||||
// ============================
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
Prompt: req.Prompt,
|
||||
Image: req.Image,
|
||||
Mode: defaultString(req.Mode, "std"),
|
||||
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
|
||||
AspectRatio: a.getAspectRatio(req.Size),
|
||||
ModelName: req.Model,
|
||||
CfgScale: 0.5,
|
||||
}
|
||||
if r.ModelName == "" {
|
||||
r.ModelName = "kling-v1"
|
||||
}
|
||||
metadata := req.Metadata
|
||||
medaBytes, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "metadata marshal metadata failed")
|
||||
}
|
||||
err = json.Unmarshal(medaBytes, &r)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshal metadata failed")
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) getAspectRatio(size string) string {
|
||||
switch size {
|
||||
case "1024x1024", "512x512":
|
||||
return "1:1"
|
||||
case "1280x720", "1920x1080":
|
||||
return "16:9"
|
||||
case "720x1280", "1080x1920":
|
||||
return "9:16"
|
||||
default:
|
||||
return "1:1"
|
||||
}
|
||||
}
|
||||
|
||||
// defaultString returns def when s is empty or only whitespace; otherwise it
// returns s unchanged (surrounding whitespace is not trimmed).
func defaultString(s, def string) string {
	if trimmed := strings.TrimSpace(s); len(trimmed) > 0 {
		return s
	}
	return def
}
|
||||
|
||||
// defaultInt returns def when v is zero and v otherwise.
func defaultInt(v int, def int) int {
	if v != 0 {
		return v
	}
	return def
}
|
||||
|
||||
// ============================
|
||||
// JWT helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) createJWTToken() (string, error) {
|
||||
return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
|
||||
parts := strings.Split(apiKey, "|")
|
||||
if len(parts) != 2 {
|
||||
return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
|
||||
}
|
||||
return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
|
||||
if accessKey == "" || secretKey == "" {
|
||||
return "", fmt.Errorf("access key and secret key are required")
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": accessKey,
|
||||
"exp": now + 1800, // 30 minutes
|
||||
"nbf": now - 5,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
token.Header["typ"] = "JWT"
|
||||
return token.SignedString([]byte(secretKey))
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
resPayload := responsePayload{}
|
||||
err := json.Unmarshal(respBody, &resPayload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal response body")
|
||||
}
|
||||
taskInfo := &relaycommon.TaskInfo{}
|
||||
taskInfo.Code = resPayload.Code
|
||||
taskInfo.TaskID = resPayload.Data.TaskId
|
||||
taskInfo.Reason = resPayload.Message
|
||||
//任务状态,枚举值:submitted(已提交)、processing(处理中)、succeed(成功)、failed(失败)
|
||||
status := resPayload.Data.TaskStatus
|
||||
switch status {
|
||||
case "submitted":
|
||||
taskInfo.Status = model.TaskStatusSubmitted
|
||||
case "processing":
|
||||
taskInfo.Status = model.TaskStatusInProgress
|
||||
case "succeed":
|
||||
taskInfo.Status = model.TaskStatusSuccess
|
||||
case "failed":
|
||||
taskInfo.Status = model.TaskStatusFailure
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown task status: %s", status)
|
||||
}
|
||||
if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 {
|
||||
video := videos[0]
|
||||
taskInfo.Url = video.Url
|
||||
}
|
||||
return taskInfo, nil
|
||||
}
|
||||
@@ -22,6 +22,10 @@ type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
}
|
||||
|
||||
// ParseTaskResult is not implemented for this adaptor: it ignores its
// argument and always returns an error.
func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
	return nil, fmt.Errorf("not implement") // todo implement this method if needed
}
|
||||
|
||||
// Init records the channel type of the current request on the adaptor.
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
	a.ChannelType = info.ChannelType
}
|
||||
|
||||
@@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = tencentStreamHandler(c, resp)
|
||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
err, usage = tencentHandler(c, resp)
|
||||
}
|
||||
|
||||
@@ -83,10 +83,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
suffix := ""
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// suffix -thinking and -nothinking
|
||||
if strings.HasSuffix(info.OriginModelName, "-thinking") {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||
parts := strings.Split(info.UpstreamModelName, "-thinking-")
|
||||
info.UpstreamModelName = parts[0]
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
|
||||
} else if strings.HasSuffix(info.OriginModelName, "-nothinking") {
|
||||
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
|
||||
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
|
||||
}
|
||||
}
|
||||
@@ -123,14 +126,23 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
model = v
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
package volcengine
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/openai"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -30,8 +34,146 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
//TODO implement me
|
||||
return nil, errors.New("not implemented")
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesEdits:
|
||||
|
||||
var requestBody bytes.Buffer
|
||||
writer := multipart.NewWriter(&requestBody)
|
||||
|
||||
writer.WriteField("model", request.Model)
|
||||
// 获取所有表单字段
|
||||
formData := c.Request.PostForm
|
||||
// 遍历表单字段并打印输出
|
||||
for key, values := range formData {
|
||||
if key == "model" {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
writer.WriteField(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the multipart form to handle both single image and multiple images
|
||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
|
||||
return nil, errors.New("failed to parse multipart form")
|
||||
}
|
||||
|
||||
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
|
||||
// Check if "image" field exists in any form, including array notation
|
||||
var imageFiles []*multipart.FileHeader
|
||||
var exists bool
|
||||
|
||||
// First check for standard "image" field
|
||||
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
|
||||
// If not found, check for "image[]" field
|
||||
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
|
||||
// If still not found, iterate through all fields to find any that start with "image["
|
||||
foundArrayImages := false
|
||||
for fieldName, files := range c.Request.MultipartForm.File {
|
||||
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
|
||||
foundArrayImages = true
|
||||
for _, file := range files {
|
||||
imageFiles = append(imageFiles, file)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no image fields found at all
|
||||
if !foundArrayImages && (len(imageFiles) == 0) {
|
||||
return nil, errors.New("image is required")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process all image files
|
||||
for i, fileHeader := range imageFiles {
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// If multiple images, use image[] as the field name
|
||||
fieldName := "image"
|
||||
if len(imageFiles) > 1 {
|
||||
fieldName = "image[]"
|
||||
}
|
||||
|
||||
// Determine MIME type based on file extension
|
||||
mimeType := detectImageMimeType(fileHeader.Filename)
|
||||
|
||||
// Create a form file with the appropriate content type
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
|
||||
h.Set("Content-Type", mimeType)
|
||||
|
||||
part, err := writer.CreatePart(h)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(part, file); err != nil {
|
||||
return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle mask file if present
|
||||
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
|
||||
maskFile, err := maskFiles[0].Open()
|
||||
if err != nil {
|
||||
return nil, errors.New("failed to open mask file")
|
||||
}
|
||||
defer maskFile.Close()
|
||||
|
||||
// Determine MIME type for mask file
|
||||
mimeType := detectImageMimeType(maskFiles[0].Filename)
|
||||
|
||||
// Create a form file with the appropriate content type
|
||||
h := make(textproto.MIMEHeader)
|
||||
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
|
||||
h.Set("Content-Type", mimeType)
|
||||
|
||||
maskPart, err := writer.CreatePart(h)
|
||||
if err != nil {
|
||||
return nil, errors.New("create form file failed for mask")
|
||||
}
|
||||
|
||||
if _, err := io.Copy(maskPart, maskFile); err != nil {
|
||||
return nil, errors.New("copy mask file failed")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("no multipart form data found")
|
||||
}
|
||||
|
||||
// 关闭 multipart 编写器以设置分界线
|
||||
writer.Close()
|
||||
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return bytes.NewReader(requestBody.Bytes()), nil
|
||||
|
||||
default:
|
||||
return request, nil
|
||||
}
|
||||
}
|
||||
|
||||
// detectImageMimeType picks an image MIME type from the file name's
// extension, compared case-insensitively: ".png" gives image/png, ".webp"
// gives image/webp, any extension starting with ".jp" (such as ".jpg" or
// ".jpeg") gives image/jpeg, and everything else falls back to image/png.
func detectImageMimeType(filename string) string {
	extension := strings.ToLower(filepath.Ext(filename))

	if extension == ".png" {
		return "image/png"
	}
	if extension == ".webp" {
		return "image/webp"
	}
	if strings.HasPrefix(extension, ".jp") {
		return "image/jpeg"
	}
	// Default to png as a fallback.
	return "image/png"
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -46,6 +188,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
|
||||
case constant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
|
||||
case constant.RelayModeImagesGenerations:
|
||||
return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
|
||||
default:
|
||||
}
|
||||
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
||||
@@ -91,6 +235,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
}
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
|
||||
err, usage = openai.OpenaiHandlerWithUsage(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
})
|
||||
|
||||
if !containStreamUsage {
|
||||
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package xinference
|
||||
|
||||
// XinRerankResponseDocument is one ranked entry of a rerank response.
type XinRerankResponseDocument struct {
	// Document is declared as any so that it can hold either a plain string
	// or a structured value.
	Document       any     `json:"document,omitempty"`
	Index          int     `json:"index"`
	RelevanceScore float64 `json:"relevance_score"`
}
|
||||
|
||||
@@ -46,13 +46,11 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
||||
relayInfo.IsStream = true
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
textRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
promptTokens, err := getClaudePromptTokens(textRequest, relayInfo)
|
||||
// count messages token error 计算promptTokens错误
|
||||
if err != nil {
|
||||
@@ -126,7 +124,7 @@ func ClaudeHelper(c *gin.Context) (claudeError *dto.ClaudeErrorWithStatusCode) {
|
||||
var httpResp *http.Response
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||
if err != nil {
|
||||
return service.ClaudeErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return service.ClaudeErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
|
||||
@@ -34,9 +34,14 @@ type ClaudeConvertInfo struct {
|
||||
}
|
||||
|
||||
// Relay format identifiers stored in RelayInfo.RelayFormat by the
// GenRelayInfo* constructors.
const (
	RelayFormatOpenAI          = "openai"
	RelayFormatClaude          = "claude"
	RelayFormatGemini          = "gemini"
	RelayFormatOpenAIResponses = "openai_responses"
	RelayFormatOpenAIAudio     = "openai_audio"
	RelayFormatOpenAIImage     = "openai_image"
	RelayFormatRerank          = "rerank"
	RelayFormatEmbedding       = "embedding"
)
|
||||
|
||||
type RerankerInfo struct {
|
||||
@@ -60,8 +65,8 @@ type RelayInfo struct {
|
||||
TokenId int
|
||||
TokenKey string
|
||||
UserId int
|
||||
Group string
|
||||
UserGroup string
|
||||
UsingGroup string // 使用的分组
|
||||
UserGroup string // 用户所在分组
|
||||
TokenUnlimited bool
|
||||
StartTime time.Time
|
||||
FirstResponseTime time.Time
|
||||
@@ -143,6 +148,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
|
||||
func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayMode = relayconstant.RelayModeRerank
|
||||
info.RelayFormat = RelayFormatRerank
|
||||
info.RerankerInfo = &RerankerInfo{
|
||||
Documents: req.Documents,
|
||||
ReturnDocuments: req.GetReturnDocuments(),
|
||||
@@ -150,9 +156,25 @@ func GenRelayInfoRerank(c *gin.Context, req *dto.RerankRequest) *RelayInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoOpenAIAudio(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatOpenAIAudio
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoEmbedding(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatEmbedding
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayMode = relayconstant.RelayModeResponses
|
||||
info.RelayFormat = RelayFormatOpenAIResponses
|
||||
|
||||
info.SupportStreamOptions = false
|
||||
|
||||
info.ResponsesUsageInfo = &ResponsesUsageInfo{
|
||||
BuiltInTools: make(map[string]*BuildInToolInfo),
|
||||
}
|
||||
@@ -175,6 +197,19 @@ func GenRelayInfoResponses(c *gin.Context, req *dto.OpenAIResponsesRequest) *Rel
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoGemini(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatGemini
|
||||
info.ShouldIncludeUsage = false
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfoImage(c *gin.Context) *RelayInfo {
|
||||
info := GenRelayInfo(c)
|
||||
info.RelayFormat = RelayFormatOpenAIImage
|
||||
return info
|
||||
}
|
||||
|
||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
channelType := c.GetInt("channel_type")
|
||||
channelId := c.GetInt("channel_id")
|
||||
@@ -184,7 +219,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
tokenId := c.GetInt("token_id")
|
||||
tokenKey := c.GetString("token_key")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
tokenUnlimited := c.GetBool("token_unlimited_quota")
|
||||
startTime := c.GetTime(constant.ContextKeyRequestStartTime)
|
||||
// firstResponseTime = time.Now() - 1 second
|
||||
@@ -204,7 +238,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
TokenId: tokenId,
|
||||
TokenKey: tokenKey,
|
||||
UserId: userId,
|
||||
Group: group,
|
||||
UsingGroup: c.GetString(constant.ContextKeyUsingGroup),
|
||||
UserGroup: c.GetString(constant.ContextKeyUserGroup),
|
||||
TokenUnlimited: tokenUnlimited,
|
||||
StartTime: startTime,
|
||||
@@ -243,10 +277,6 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
if streamSupportedChannels[info.ChannelType] {
|
||||
info.SupportStreamOptions = true
|
||||
}
|
||||
// responses 模式不支持 StreamOptions
|
||||
if relayconstant.RelayModeResponses == info.RelayMode {
|
||||
info.SupportStreamOptions = false
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
@@ -283,3 +313,22 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// TaskSubmitReq is the generic request body for submitting a task. Every
// field except Prompt is omitted from the JSON when it holds its zero value.
type TaskSubmitReq struct {
	Prompt   string `json:"prompt"`
	Model    string `json:"model,omitempty"`
	Mode     string `json:"mode,omitempty"`
	Image    string `json:"image,omitempty"`
	Size     string `json:"size,omitempty"`
	Duration int    `json:"duration,omitempty"`
	// Metadata carries free-form extra parameters.
	Metadata map[string]any `json:"metadata,omitempty"`
}
|
||||
|
||||
// TaskInfo is a task result in a shape that does not depend on the upstream
// service.
type TaskInfo struct {
	Code   int    `json:"code"`
	TaskID string `json:"task_id"`
	Status string `json:"status"`
	// Reason, Url and Progress are omitted from the JSON when empty.
	Reason   string `json:"reason,omitempty"`
	Url      string `json:"url,omitempty"`
	Progress string `json:"progress,omitempty"`
}
|
||||
|
||||
@@ -38,10 +38,16 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
}
|
||||
if info.ReturnDocuments {
|
||||
var document any
|
||||
if result.Document == "" {
|
||||
document = info.Documents[result.Index]
|
||||
} else {
|
||||
document = result.Document
|
||||
if result.Document != nil {
|
||||
if doc, ok := result.Document.(string); ok {
|
||||
if doc == "" {
|
||||
document = info.Documents[result.Index]
|
||||
} else {
|
||||
document = doc
|
||||
}
|
||||
} else {
|
||||
document = result.Document
|
||||
}
|
||||
}
|
||||
respResult.Document = document
|
||||
}
|
||||
|
||||
@@ -38,6 +38,12 @@ const (
|
||||
RelayModeSunoFetchByID
|
||||
RelayModeSunoSubmit
|
||||
|
||||
RelayModeKlingFetchByID
|
||||
RelayModeKlingSubmit
|
||||
|
||||
RelayModeJimengFetchByID
|
||||
RelayModeJimengSubmit
|
||||
|
||||
RelayModeRerank
|
||||
|
||||
RelayModeResponses
|
||||
@@ -77,7 +83,7 @@ func Path2RelayMode(path string) int {
|
||||
relayMode = RelayModeRerank
|
||||
} else if strings.HasPrefix(path, "/v1/realtime") {
|
||||
relayMode = RelayModeRealtime
|
||||
} else if strings.HasPrefix(path, "/v1beta/models") {
|
||||
} else if strings.HasPrefix(path, "/v1beta/models") || strings.HasPrefix(path, "/v1/models") {
|
||||
relayMode = RelayModeGemini
|
||||
}
|
||||
return relayMode
|
||||
@@ -133,3 +139,23 @@ func Path2RelaySuno(method, path string) int {
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
func Path2RelayKling(method, path string) int {
|
||||
relayMode := RelayModeUnknown
|
||||
if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
|
||||
relayMode = RelayModeKlingSubmit
|
||||
} else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
|
||||
relayMode = RelayModeKlingFetchByID
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
func Path2RelayJimeng(method, path string) int {
|
||||
relayMode := RelayModeUnknown
|
||||
if method == http.MethodPost && strings.HasSuffix(path, "/video/generations") {
|
||||
relayMode = RelayModeJimengSubmit
|
||||
} else if method == http.MethodGet && strings.Contains(path, "/video/generations/") {
|
||||
relayMode = RelayModeJimengFetchByID
|
||||
}
|
||||
return relayMode
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
||||
token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
|
||||
token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
|
||||
return token
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embed
|
||||
}
|
||||
|
||||
func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
relayInfo := relaycommon.GenRelayInfoEmbedding(c)
|
||||
|
||||
var embeddingRequest *dto.EmbeddingRequest
|
||||
err := common.UnmarshalBodyReusable(c, &embeddingRequest)
|
||||
@@ -47,13 +47,11 @@ func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
embeddingRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
promptToken := getEmbeddingPromptToken(*embeddingRequest)
|
||||
relayInfo.PromptTokens = promptToken
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/model_setting"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -59,7 +60,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
|
||||
return sensitiveWords, err
|
||||
}
|
||||
|
||||
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
|
||||
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
||||
// 计算输入 token 数量
|
||||
var inputTexts []string
|
||||
for _, content := range req.Contents {
|
||||
@@ -71,9 +72,36 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
|
||||
}
|
||||
|
||||
inputText := strings.Join(inputTexts, "\n")
|
||||
inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
|
||||
inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
|
||||
info.PromptTokens = inputTokens
|
||||
return inputTokens, err
|
||||
return inputTokens
|
||||
}
|
||||
|
||||
func isNoThinkingRequest(req *gemini.GeminiChatRequest) bool {
|
||||
if req.GenerationConfig.ThinkingConfig != nil && req.GenerationConfig.ThinkingConfig.ThinkingBudget != nil {
|
||||
return *req.GenerationConfig.ThinkingConfig.ThinkingBudget <= 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// trimModelThinking normalises a model name that carries a thinking marker.
//
//   - A trailing "-nothinking" is removed.
//   - A trailing "-thinking" is removed.
//   - A name containing "-thinking-" (e.g. "-thinking-<number>") is cut at the
//     first occurrence and returned as "<base>-thinking".
//
// Names matching none of these forms are returned unchanged.
func trimModelThinking(modelName string) string {
	// Suffix forms are checked first, in this order, before the embedded form.
	for _, suffix := range []string{"-nothinking", "-thinking"} {
		if strings.HasSuffix(modelName, suffix) {
			return strings.TrimSuffix(modelName, suffix)
		}
	}

	// Drop everything after the first "-thinking-" but keep the "-thinking" tag.
	if base, _, found := strings.Cut(modelName, "-thinking-"); found {
		return base + "-thinking"
	}
	return modelName
}
|
||||
|
||||
func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
@@ -83,7 +111,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
relayInfo := relaycommon.GenRelayInfoGemini(c)
|
||||
|
||||
// 检查 Gemini 流式模式
|
||||
checkGeminiStreamMode(c, relayInfo)
|
||||
@@ -97,7 +125,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
}
|
||||
|
||||
// model mapped 模型映射
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
|
||||
}
|
||||
@@ -106,13 +134,28 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
promptTokens := value.(int)
|
||||
relayInfo.SetPromptTokens(promptTokens)
|
||||
} else {
|
||||
promptTokens, err := getGeminiInputTokens(req, relayInfo)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
|
||||
}
|
||||
promptTokens := getGeminiInputTokens(req, relayInfo)
|
||||
c.Set("prompt_tokens", promptTokens)
|
||||
}
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
if isNoThinkingRequest(req) {
|
||||
// check is thinking
|
||||
if !strings.Contains(relayInfo.OriginModelName, "-nothinking") {
|
||||
// try to get no thinking model price
|
||||
noThinkingModelName := relayInfo.OriginModelName + "-nothinking"
|
||||
containPrice := helper.ContainPriceOrRatio(noThinkingModelName)
|
||||
if containPrice {
|
||||
relayInfo.OriginModelName = noThinkingModelName
|
||||
relayInfo.UpstreamModelName = noThinkingModelName
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.GenerationConfig.ThinkingConfig == nil {
|
||||
gemini.ThinkingAdaptor(req, relayInfo)
|
||||
}
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
@@ -155,14 +198,33 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
println("Gemini request body: %s", string(requestBody))
|
||||
}
|
||||
|
||||
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
common.LogError(c, "Do gemini request failed: "+err.Error())
|
||||
return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
var httpResp *http.Response
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
openaiErr = service.RelayErrorHandler(httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
|
||||
if openaiErr != nil {
|
||||
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
@@ -4,12 +4,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
common2 "one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/common"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
|
||||
func ModelMappedHelper(c *gin.Context, info *common.RelayInfo, request any) error {
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
if modelMapping != "" && modelMapping != "{}" {
|
||||
@@ -50,5 +52,41 @@ func ModelMappedHelper(c *gin.Context, info *common.RelayInfo) error {
|
||||
info.UpstreamModelName = currentModel
|
||||
}
|
||||
}
|
||||
if request != nil {
|
||||
switch info.RelayFormat {
|
||||
case common.RelayFormatGemini:
|
||||
// Gemini 模型映射
|
||||
case common.RelayFormatClaude:
|
||||
if claudeRequest, ok := request.(*dto.ClaudeRequest); ok {
|
||||
claudeRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatOpenAIResponses:
|
||||
if openAIResponsesRequest, ok := request.(*dto.OpenAIResponsesRequest); ok {
|
||||
openAIResponsesRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatOpenAIAudio:
|
||||
if openAIAudioRequest, ok := request.(*dto.AudioRequest); ok {
|
||||
openAIAudioRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatOpenAIImage:
|
||||
if imageRequest, ok := request.(*dto.ImageRequest); ok {
|
||||
imageRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatRerank:
|
||||
if rerankRequest, ok := request.(*dto.RerankRequest); ok {
|
||||
rerankRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
case common.RelayFormatEmbedding:
|
||||
if embeddingRequest, ok := request.(*dto.EmbeddingRequest); ok {
|
||||
embeddingRequest.Model = info.UpstreamModelName
|
||||
}
|
||||
default:
|
||||
if openAIRequest, ok := request.(*dto.GeneralOpenAIRequest); ok {
|
||||
openAIRequest.Model = info.UpstreamModelName
|
||||
} else {
|
||||
common2.LogWarn(c, fmt.Sprintf("model mapped but request type %T not supported", request))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,12 +5,17 @@ import (
|
||||
"one-api/common"
|
||||
constant2 "one-api/constant"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GroupRatioInfo holds the group ratio values resolved for a request.
type GroupRatioInfo struct {
	// GroupRatio is the ratio used in quota calculations.
	GroupRatio float64
	// GroupSpecialRatio is the user-group specific ratio; HandleGroupRatio
	// leaves it at -1 when no such ratio was found.
	GroupSpecialRatio float64
	// HasSpecialRatio is true when a user-group specific ratio was found.
	HasSpecialRatio bool
}
|
||||
|
||||
type PriceData struct {
|
||||
ModelPrice float64
|
||||
ModelRatio float64
|
||||
@@ -18,23 +23,51 @@ type PriceData struct {
|
||||
CacheRatio float64
|
||||
CacheCreationRatio float64
|
||||
ImageRatio float64
|
||||
GroupRatio float64
|
||||
UserGroupRatio float64
|
||||
UsePrice bool
|
||||
ShouldPreConsumedQuota int
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
func (p PriceData) ToSetting() string {
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, ShouldPreConsumedQuota: %d, ImageRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.ShouldPreConsumedQuota, p.ImageRatio)
|
||||
}
|
||||
|
||||
// HandleGroupRatio checks for "auto_group" in the context and updates the group ratio and relayInfo.UsingGroup if present
|
||||
func HandleGroupRatio(ctx *gin.Context, relayInfo *relaycommon.RelayInfo) GroupRatioInfo {
|
||||
groupRatioInfo := GroupRatioInfo{
|
||||
GroupRatio: 1.0, // default ratio
|
||||
GroupSpecialRatio: -1,
|
||||
}
|
||||
|
||||
// check auto group
|
||||
autoGroup, exists := ctx.Get("auto_group")
|
||||
if exists {
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("final group: %s", autoGroup))
|
||||
}
|
||||
relayInfo.UsingGroup = autoGroup.(string)
|
||||
}
|
||||
|
||||
// check user group special ratio
|
||||
userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
|
||||
if ok {
|
||||
// user group special ratio
|
||||
groupRatioInfo.GroupSpecialRatio = userGroupRatio
|
||||
groupRatioInfo.GroupRatio = userGroupRatio
|
||||
groupRatioInfo.HasSpecialRatio = true
|
||||
} else {
|
||||
// normal group ratio
|
||||
groupRatioInfo.GroupRatio = ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
|
||||
}
|
||||
|
||||
return groupRatioInfo
|
||||
}
|
||||
|
||||
func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens int, maxTokens int) (PriceData, error) {
|
||||
modelPrice, usePrice := operation_setting.GetModelPrice(info.OriginModelName, false)
|
||||
groupRatio := setting.GetGroupRatio(info.Group)
|
||||
userGroupRatio, ok := setting.GetGroupGroupRatio(info.UserGroup, info.Group)
|
||||
if ok {
|
||||
groupRatio = userGroupRatio
|
||||
}
|
||||
modelPrice, usePrice := ratio_setting.GetModelPrice(info.OriginModelName, false)
|
||||
|
||||
groupRatioInfo := HandleGroupRatio(c, info)
|
||||
|
||||
var preConsumedQuota int
|
||||
var modelRatio float64
|
||||
var completionRatio float64
|
||||
@@ -47,7 +80,7 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
preConsumedTokens = promptTokens + maxTokens
|
||||
}
|
||||
var success bool
|
||||
modelRatio, success = operation_setting.GetModelRatio(info.OriginModelName)
|
||||
modelRatio, success = ratio_setting.GetModelRatio(info.OriginModelName)
|
||||
if !success {
|
||||
acceptUnsetRatio := false
|
||||
if accept, ok := info.UserSetting[constant2.UserAcceptUnsetRatioModel]; ok {
|
||||
@@ -60,22 +93,21 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
return PriceData{}, fmt.Errorf("模型 %s 倍率或价格未配置,请联系管理员设置或开始自用模式;Model %s ratio or price not set, please set or start self-use mode", info.OriginModelName, info.OriginModelName)
|
||||
}
|
||||
}
|
||||
completionRatio = operation_setting.GetCompletionRatio(info.OriginModelName)
|
||||
cacheRatio, _ = operation_setting.GetCacheRatio(info.OriginModelName)
|
||||
cacheCreationRatio, _ = operation_setting.GetCreateCacheRatio(info.OriginModelName)
|
||||
imageRatio, _ = operation_setting.GetImageRatio(info.OriginModelName)
|
||||
ratio := modelRatio * groupRatio
|
||||
completionRatio = ratio_setting.GetCompletionRatio(info.OriginModelName)
|
||||
cacheRatio, _ = ratio_setting.GetCacheRatio(info.OriginModelName)
|
||||
cacheCreationRatio, _ = ratio_setting.GetCreateCacheRatio(info.OriginModelName)
|
||||
imageRatio, _ = ratio_setting.GetImageRatio(info.OriginModelName)
|
||||
ratio := modelRatio * groupRatioInfo.GroupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||
}
|
||||
|
||||
priceData := PriceData{
|
||||
ModelPrice: modelPrice,
|
||||
ModelRatio: modelRatio,
|
||||
CompletionRatio: completionRatio,
|
||||
GroupRatio: groupRatio,
|
||||
UserGroupRatio: userGroupRatio,
|
||||
GroupRatioInfo: groupRatioInfo,
|
||||
UsePrice: usePrice,
|
||||
CacheRatio: cacheRatio,
|
||||
ImageRatio: imageRatio,
|
||||
@@ -90,12 +122,41 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
|
||||
return priceData, nil
|
||||
}
|
||||
|
||||
type PerCallPriceData struct {
|
||||
ModelPrice float64
|
||||
Quota int
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
|
||||
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) PerCallPriceData {
|
||||
groupRatioInfo := HandleGroupRatio(c, info)
|
||||
|
||||
modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if !success {
|
||||
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
|
||||
priceData := PerCallPriceData{
|
||||
ModelPrice: modelPrice,
|
||||
Quota: quota,
|
||||
GroupRatioInfo: groupRatioInfo,
|
||||
}
|
||||
return priceData
|
||||
}
|
||||
|
||||
func ContainPriceOrRatio(modelName string) bool {
|
||||
_, ok := operation_setting.GetModelPrice(modelName, false)
|
||||
_, ok := ratio_setting.GetModelPrice(modelName, false)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
_, ok = operation_setting.GetModelRatio(modelName)
|
||||
_, ok = ratio_setting.GetModelRatio(modelName)
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
|
||||
"one-api/relay/constant"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -44,6 +46,11 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
|
||||
if info.ApiType == constant.APITypeVolcEngine {
|
||||
watermark := formData.Has("watermark")
|
||||
imageRequest.Watermark = &watermark
|
||||
}
|
||||
default:
|
||||
err := common.UnmarshalBodyReusable(c, imageRequest)
|
||||
if err != nil {
|
||||
@@ -102,7 +109,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
|
||||
}
|
||||
|
||||
func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
relayInfo := relaycommon.GenRelayInfoImage(c)
|
||||
|
||||
imageRequest, err := getAndValidImageRequest(c, relayInfo)
|
||||
if err != nil {
|
||||
@@ -110,13 +117,11 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, imageRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
imageRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, len(imageRequest.Prompt), 0)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
@@ -162,7 +167,7 @@ func ImageHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
||||
|
||||
// reset model price
|
||||
priceData.ModelPrice *= sizeRatio * qualityRatio * float64(imageRequest.N)
|
||||
quota = int(priceData.ModelPrice * priceData.GroupRatio * common.QuotaPerUnit)
|
||||
quota = int(priceData.ModelPrice * priceData.GroupRatioInfo.GroupRatio * common.QuotaPerUnit)
|
||||
userQuota, err = model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -174,18 +174,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||||
}
|
||||
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
||||
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if !success {
|
||||
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
groupRatio := setting.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
|
||||
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
||||
|
||||
userQuota, err := model.GetUserQuota(userId, false)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
@@ -193,9 +184,8 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
Description: err.Error(),
|
||||
}
|
||||
}
|
||||
quota := int(ratio * common.QuotaPerUnit)
|
||||
|
||||
if userQuota-quota < 0 {
|
||||
if userQuota-priceData.Quota < 0 {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "quota_not_enough",
|
||||
@@ -210,26 +200,18 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
}
|
||||
defer func() {
|
||||
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||||
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
||||
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
//err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, constant.MjActionSwapFace)
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
||||
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
|
||||
other := service.GenerateMjOtherInfo(priceData)
|
||||
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
||||
priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
|
||||
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
|
||||
}
|
||||
}()
|
||||
midjResponse := &mjResp.Response
|
||||
@@ -250,7 +232,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
|
||||
Progress: "0%",
|
||||
FailReason: "",
|
||||
ChannelId: c.GetInt("channel_id"),
|
||||
Quota: quota,
|
||||
Quota: priceData.Quota,
|
||||
}
|
||||
err = midjourneyTask.Insert()
|
||||
if err != nil {
|
||||
@@ -480,18 +462,9 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
|
||||
modelName := service.CoverActionToModelName(midjRequest.Action)
|
||||
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
|
||||
// 如果没有配置价格,则使用默认价格
|
||||
if !success {
|
||||
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
modelPrice = defaultPrice
|
||||
}
|
||||
}
|
||||
groupRatio := setting.GetGroupRatio(group)
|
||||
ratio := modelPrice * groupRatio
|
||||
|
||||
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
||||
|
||||
userQuota, err := model.GetUserQuota(userId, false)
|
||||
if err != nil {
|
||||
return &dto.MidjourneyResponse{
|
||||
@@ -499,9 +472,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
Description: err.Error(),
|
||||
}
|
||||
}
|
||||
quota := int(ratio * common.QuotaPerUnit)
|
||||
|
||||
if consumeQuota && userQuota-quota < 0 {
|
||||
if consumeQuota && userQuota-priceData.Quota < 0 {
|
||||
return &dto.MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "quota_not_enough",
|
||||
@@ -516,22 +488,17 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
|
||||
defer func() {
|
||||
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||||
err := service.PostConsumeQuota(relayInfo, quota, 0, true)
|
||||
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", modelPrice, groupRatio, midjRequest.Action, midjResponse.Result)
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
||||
quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
|
||||
other := service.GenerateMjOtherInfo(priceData)
|
||||
model.RecordConsumeLog(c, userId, channelId, 0, 0, modelName, tokenName,
|
||||
priceData.Quota, logContent, tokenId, userQuota, 0, false, group, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
|
||||
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -559,7 +526,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
||||
Progress: "0%",
|
||||
FailReason: "",
|
||||
ChannelId: c.GetInt("channel_id"),
|
||||
Quota: quota,
|
||||
Quota: priceData.Quota,
|
||||
}
|
||||
if midjResponse.Code == 3 {
|
||||
//无实例账号自动禁用渠道(No available account instance)
|
||||
|
||||
@@ -90,15 +90,16 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
// get & validate textRequest 获取并验证文本请求
|
||||
textRequest, err := getAndValidateTextRequest(c, relayInfo)
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
|
||||
return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if textRequest.WebSearchOptions != nil {
|
||||
c.Set("chat_completion_web_search_context_size", textRequest.WebSearchOptions.SearchContextSize)
|
||||
}
|
||||
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
words, err := checkRequestSensitive(textRequest, relayInfo)
|
||||
if err != nil {
|
||||
@@ -107,13 +108,11 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, textRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
textRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
// 获取 promptTokens,如果上下文中已经存在,则直接使用
|
||||
var promptTokens int
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
@@ -252,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
||||
case relayconstant.RelayModeChatCompletions:
|
||||
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
|
||||
case relayconstant.RelayModeCompletions:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||
promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||
case relayconstant.RelayModeModerations:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||
case relayconstant.RelayModeEmbeddings:
|
||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||
default:
|
||||
err = errors.New("unknown relay mode")
|
||||
promptTokens = 0
|
||||
@@ -361,9 +360,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
cacheRatio := priceData.CacheRatio
|
||||
imageRatio := priceData.ImageRatio
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
userGroupRatio := priceData.UserGroupRatio
|
||||
|
||||
// Convert values to decimal for precise calculation
|
||||
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
||||
@@ -511,7 +509,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
if extraContent != "" {
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, userGroupRatio)
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
if imageTokens != 0 {
|
||||
other["image"] = true
|
||||
other["image_ratio"] = imageRatio
|
||||
@@ -543,5 +541,5 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
other["audio_input_price"] = audioInputPrice
|
||||
}
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
|
||||
}
|
||||
|
||||
@@ -22,6 +22,8 @@ import (
|
||||
"one-api/relay/channel/palm"
|
||||
"one-api/relay/channel/perplexity"
|
||||
"one-api/relay/channel/siliconflow"
|
||||
"one-api/relay/channel/task/jimeng"
|
||||
"one-api/relay/channel/task/kling"
|
||||
"one-api/relay/channel/task/suno"
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/vertex"
|
||||
@@ -101,6 +103,10 @@ func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
|
||||
// return &aiproxy.Adaptor{}
|
||||
case commonconstant.TaskPlatformSuno:
|
||||
return &suno.TaskAdaptor{}
|
||||
case commonconstant.TaskPlatformKling:
|
||||
return &kling.TaskAdaptor{}
|
||||
case commonconstant.TaskPlatformJimeng:
|
||||
return &jimeng.TaskAdaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -15,8 +14,9 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
/*
|
||||
@@ -38,9 +38,12 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
|
||||
modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
|
||||
modelPrice, success := operation_setting.GetModelPrice(modelName, true)
|
||||
if platform == constant.TaskPlatformKling {
|
||||
modelName = relayInfo.OriginModelName
|
||||
}
|
||||
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
||||
if !success {
|
||||
defaultPrice, ok := operation_setting.GetDefaultModelRatioMap()[modelName]
|
||||
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
|
||||
if !ok {
|
||||
modelPrice = 0.1
|
||||
} else {
|
||||
@@ -49,8 +52,14 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
|
||||
// 预扣
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
ratio := modelPrice * groupRatio
|
||||
groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
|
||||
var ratio float64
|
||||
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
|
||||
if hasUserGroupRatio {
|
||||
ratio = modelPrice * userGroupRatio
|
||||
} else {
|
||||
ratio = modelPrice * groupRatio
|
||||
}
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
@@ -119,12 +128,19 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)
|
||||
gRatio := groupRatio
|
||||
if hasUserGroupRatio {
|
||||
gRatio = userGroupRatio
|
||||
}
|
||||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action)
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = modelPrice
|
||||
other["group_ratio"] = groupRatio
|
||||
if hasUserGroupRatio {
|
||||
other["user_group_ratio"] = userGroupRatio
|
||||
}
|
||||
model.RecordConsumeLog(c, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
|
||||
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
|
||||
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.UsingGroup, other)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
@@ -137,10 +153,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
relayInfo.ConsumeQuota = true
|
||||
// insert task
|
||||
task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
|
||||
task := model.InitTask(platform, relayInfo)
|
||||
task.TaskID = taskID
|
||||
task.Quota = quota
|
||||
task.Data = taskData
|
||||
task.Action = relayInfo.Action
|
||||
err = task.Insert()
|
||||
if err != nil {
|
||||
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|
||||
@@ -150,8 +167,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
||||
}
|
||||
|
||||
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
||||
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
||||
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|
||||
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
||||
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|
||||
relayconstant.RelayModeKlingFetchByID: videoFetchByIDRespBodyBuilder,
|
||||
}
|
||||
|
||||
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
|
||||
@@ -226,6 +244,27 @@ func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dt
|
||||
return
|
||||
}
|
||||
|
||||
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
|
||||
taskId := c.Param("task_id")
|
||||
userId := c.GetInt("id")
|
||||
|
||||
originTask, exist, err := model.GetByTaskId(userId, taskId)
|
||||
if err != nil {
|
||||
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !exist {
|
||||
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: TaskModel2Dto(originTask),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
|
||||
return &dto.TaskDto{
|
||||
TaskID: task.TaskID,
|
||||
|
||||
@@ -14,12 +14,10 @@ import (
|
||||
)
|
||||
|
||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||
token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
|
||||
token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
|
||||
for _, document := range rerankRequest.Documents {
|
||||
tkm, err := service.CountTokenInput(document, rerankRequest.Model)
|
||||
if err == nil {
|
||||
token += tkm
|
||||
}
|
||||
tkm := service.CountTokenInput(document, rerankRequest.Model)
|
||||
token += tkm
|
||||
}
|
||||
return token
|
||||
}
|
||||
@@ -42,13 +40,11 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
|
||||
return service.OpenAIErrorWrapperLocal(fmt.Errorf("documents is empty"), "invalid_documents", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, rerankRequest)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
rerankRequest.Model = relayInfo.UpstreamModelName
|
||||
|
||||
promptToken := getRerankPromptToken(*rerankRequest)
|
||||
relayInfo.PromptTokens = promptToken
|
||||
|
||||
@@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom
|
||||
return sensitiveWords, err
|
||||
}
|
||||
|
||||
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) {
|
||||
inputTokens, err := service.CountTokenInput(req.Input, req.Model)
|
||||
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
|
||||
inputTokens := service.CountTokenInput(req.Input, req.Model)
|
||||
info.PromptTokens = inputTokens
|
||||
return inputTokens, err
|
||||
return inputTokens
|
||||
}
|
||||
|
||||
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
@@ -63,19 +63,16 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
err = helper.ModelMappedHelper(c, relayInfo)
|
||||
err = helper.ModelMappedHelper(c, relayInfo, req)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
|
||||
}
|
||||
req.Model = relayInfo.UpstreamModelName
|
||||
|
||||
if value, exists := c.Get("prompt_tokens"); exists {
|
||||
promptTokens := value.(int)
|
||||
relayInfo.SetPromptTokens(promptTokens)
|
||||
} else {
|
||||
promptTokens, err := getInputTokens(req, relayInfo)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
|
||||
}
|
||||
promptTokens := getInputTokens(req, relayInfo)
|
||||
c.Set("prompt_tokens", promptTokens)
|
||||
}
|
||||
|
||||
@@ -6,12 +6,10 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
)
|
||||
|
||||
func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
@@ -39,43 +37,14 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
||||
//isModelMapped = true
|
||||
}
|
||||
}
|
||||
//relayInfo.UpstreamModelName = textRequest.Model
|
||||
modelPrice, getModelPriceSuccess := operation_setting.GetModelPrice(relayInfo.UpstreamModelName, false)
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
|
||||
var preConsumedQuota int
|
||||
var ratio float64
|
||||
var modelRatio float64
|
||||
//err := service.SensitiveWordsCheck(textRequest)
|
||||
|
||||
//if constant.ShouldCheckPromptSensitive() {
|
||||
// err = checkRequestSensitive(textRequest, relayInfo)
|
||||
// if err != nil {
|
||||
// return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
|
||||
// }
|
||||
//}
|
||||
|
||||
//promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo)
|
||||
//// count messages token error 计算promptTokens错误
|
||||
//if err != nil {
|
||||
// return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
|
||||
//}
|
||||
//
|
||||
if !getModelPriceSuccess {
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
//if realtimeEvent.Session.MaxResponseOutputTokens != 0 {
|
||||
// preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens)
|
||||
//}
|
||||
modelRatio, _ = operation_setting.GetModelRatio(relayInfo.UpstreamModelName)
|
||||
ratio = modelRatio * groupRatio
|
||||
preConsumedQuota = int(float64(preConsumedTokens) * ratio)
|
||||
} else {
|
||||
preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio)
|
||||
relayInfo.UsePrice = true
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// pre-consume quota 预消耗配额
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo)
|
||||
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if openaiErr != nil {
|
||||
return openaiErr
|
||||
}
|
||||
@@ -113,6 +82,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
|
||||
return openaiErr
|
||||
}
|
||||
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
|
||||
userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "")
|
||||
userQuota, priceData, "")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
|
||||
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
|
||||
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
|
||||
apiRouter.GET("/ratio_config", middleware.CriticalRateLimit(), controller.GetRatioConfig)
|
||||
|
||||
userRoute := apiRouter.Group("/user")
|
||||
{
|
||||
@@ -83,6 +84,12 @@ func SetApiRouter(router *gin.Engine) {
|
||||
optionRoute.POST("/rest_model_ratio", controller.ResetModelRatio)
|
||||
optionRoute.POST("/migrate_console_setting", controller.MigrateConsoleSetting) // 用于迁移检测的旧键,下个版本会删除
|
||||
}
|
||||
ratioSyncRoute := apiRouter.Group("/ratio_sync")
|
||||
ratioSyncRoute.Use(middleware.RootAuth())
|
||||
{
|
||||
ratioSyncRoute.GET("/channels", controller.GetSyncableChannels)
|
||||
ratioSyncRoute.POST("/fetch", controller.FetchUpstreamRatios)
|
||||
}
|
||||
channelRoute := apiRouter.Group("/channel")
|
||||
channelRoute.Use(middleware.AdminAuth())
|
||||
{
|
||||
@@ -118,6 +125,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
tokenRoute.POST("/", controller.AddToken)
|
||||
tokenRoute.PUT("/", controller.UpdateToken)
|
||||
tokenRoute.DELETE("/:id", controller.DeleteToken)
|
||||
tokenRoute.POST("/batch", controller.DeleteTokenBatch)
|
||||
}
|
||||
redemptionRoute := apiRouter.Group("/redemption")
|
||||
redemptionRoute.Use(middleware.AdminAuth())
|
||||
|
||||
@@ -14,6 +14,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
||||
SetApiRouter(router)
|
||||
SetDashboardRouter(router)
|
||||
SetRelayRouter(router)
|
||||
SetVideoRouter(router)
|
||||
frontendBaseUrl := os.Getenv("FRONTEND_BASE_URL")
|
||||
if common.IsMasterNode && frontendBaseUrl != "" {
|
||||
frontendBaseUrl = ""
|
||||
|
||||
@@ -63,6 +63,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
|
||||
httpRouter.POST("/moderations", controller.Relay)
|
||||
httpRouter.POST("/rerank", controller.Relay)
|
||||
httpRouter.POST("/models/*path", controller.Relay)
|
||||
}
|
||||
|
||||
relayMjRouter := router.Group("/mj")
|
||||
|
||||
24
router/video-router.go
Normal file
24
router/video-router.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"one-api/controller"
|
||||
"one-api/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func SetVideoRouter(router *gin.Engine) {
|
||||
videoV1Router := router.Group("/v1")
|
||||
videoV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
videoV1Router.POST("/video/generations", controller.RelayTask)
|
||||
videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
|
||||
}
|
||||
|
||||
klingV1Router := router.Group("/kling/v1")
|
||||
klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
klingV1Router.POST("/videos/text2video", controller.RelayTask)
|
||||
klingV1Router.POST("/videos/image2video", controller.RelayTask)
|
||||
}
|
||||
}
|
||||
@@ -59,6 +59,8 @@ func ShouldDisableChannel(channelType int, err *dto.OpenAIErrorWithStatusCode) b
|
||||
return true
|
||||
case "billing_not_active":
|
||||
return true
|
||||
case "pre_consume_token_quota_failed":
|
||||
return true
|
||||
}
|
||||
switch err.Error.Type {
|
||||
case "insufficient_quota":
|
||||
|
||||
@@ -276,12 +276,15 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
|
||||
}
|
||||
if info.Done {
|
||||
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
|
||||
if info.ClaudeConvertInfo.Usage != nil {
|
||||
oaiUsage := info.ClaudeConvertInfo.Usage
|
||||
if oaiUsage != nil {
|
||||
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
|
||||
Type: "message_delta",
|
||||
Usage: &dto.ClaudeUsage{
|
||||
InputTokens: info.ClaudeConvertInfo.Usage.PromptTokens,
|
||||
OutputTokens: info.ClaudeConvertInfo.Usage.CompletionTokens,
|
||||
InputTokens: oaiUsage.PromptTokens,
|
||||
OutputTokens: oaiUsage.CompletionTokens,
|
||||
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
|
||||
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
|
||||
},
|
||||
Delta: &dto.ClaudeMediaMessage{
|
||||
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
|
||||
|
||||
@@ -29,9 +29,11 @@ func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int)
|
||||
func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIErrorWithStatusCode {
|
||||
text := err.Error()
|
||||
lowerText := strings.ToLower(text)
|
||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||
text = "请求上游地址失败"
|
||||
if !strings.HasPrefix(lowerText, "get file base64 from url") && !strings.HasPrefix(lowerText, "mime type is not supported") {
|
||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||
text = "请求上游地址失败"
|
||||
}
|
||||
}
|
||||
openAIError := dto.OpenAIError{
|
||||
Message: text,
|
||||
@@ -53,9 +55,11 @@ func OpenAIErrorWrapperLocal(err error, code string, statusCode int) *dto.OpenAI
|
||||
func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode {
|
||||
text := err.Error()
|
||||
lowerText := strings.ToLower(text)
|
||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||
text = "请求上游地址失败"
|
||||
if !strings.HasPrefix(lowerText, "get file base64 from url") {
|
||||
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") {
|
||||
common.SysLog(fmt.Sprintf("error: %s", text))
|
||||
text = "请求上游地址失败"
|
||||
}
|
||||
}
|
||||
claudeError := dto.ClaudeError{
|
||||
Message: text,
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
|
||||
@@ -30,9 +32,104 @@ func GetFileBase64FromUrl(url string) (*dto.LocalFileData, error) {
|
||||
// Convert to base64
|
||||
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
|
||||
|
||||
mimeType := resp.Header.Get("Content-Type")
|
||||
if len(strings.Split(mimeType, ";")) > 1 {
|
||||
// If Content-Type has parameters, take the first part
|
||||
mimeType = strings.Split(mimeType, ";")[0]
|
||||
}
|
||||
if mimeType == "application/octet-stream" {
|
||||
if common.DebugEnabled {
|
||||
println("MIME type is application/octet-stream, trying to guess from URL or filename")
|
||||
}
|
||||
// try to guess the MIME type from the url last segment
|
||||
urlParts := strings.Split(url, "/")
|
||||
if len(urlParts) > 0 {
|
||||
lastSegment := urlParts[len(urlParts)-1]
|
||||
if strings.Contains(lastSegment, ".") {
|
||||
// Extract the file extension
|
||||
filename := strings.Split(lastSegment, ".")
|
||||
if len(filename) > 1 {
|
||||
ext := strings.ToLower(filename[len(filename)-1])
|
||||
// Guess MIME type based on file extension
|
||||
mimeType = GetMimeTypeByExtension(ext)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// try to guess the MIME type from the file extension
|
||||
fileName := resp.Header.Get("Content-Disposition")
|
||||
if fileName != "" {
|
||||
// Extract the filename from the Content-Disposition header
|
||||
parts := strings.Split(fileName, ";")
|
||||
for _, part := range parts {
|
||||
if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
|
||||
fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
|
||||
// Remove quotes if present
|
||||
if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
|
||||
fileName = fileName[1 : len(fileName)-1]
|
||||
}
|
||||
// Guess MIME type based on file extension
|
||||
if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" {
|
||||
mimeType = GetMimeTypeByExtension(ext)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &dto.LocalFileData{
|
||||
Base64Data: base64Data,
|
||||
MimeType: resp.Header.Get("Content-Type"),
|
||||
MimeType: mimeType,
|
||||
Size: int64(len(fileBytes)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func GetMimeTypeByExtension(ext string) string {
|
||||
// Convert to lowercase for case-insensitive comparison
|
||||
ext = strings.ToLower(ext)
|
||||
switch ext {
|
||||
// Text files
|
||||
case "txt", "md", "markdown", "csv", "json", "xml", "html", "htm":
|
||||
return "text/plain"
|
||||
|
||||
// Image files
|
||||
case "jpg", "jpeg":
|
||||
return "image/jpeg"
|
||||
case "png":
|
||||
return "image/png"
|
||||
case "gif":
|
||||
return "image/gif"
|
||||
|
||||
// Audio files
|
||||
case "mp3":
|
||||
return "audio/mp3"
|
||||
case "wav":
|
||||
return "audio/wav"
|
||||
case "mpeg":
|
||||
return "audio/mpeg"
|
||||
|
||||
// Video files
|
||||
case "mp4":
|
||||
return "video/mp4"
|
||||
case "wmv":
|
||||
return "video/wmv"
|
||||
case "flv":
|
||||
return "video/flv"
|
||||
case "mov":
|
||||
return "video/mov"
|
||||
case "mpg":
|
||||
return "video/mpg"
|
||||
case "avi":
|
||||
return "video/avi"
|
||||
case "mpegps":
|
||||
return "video/mpegps"
|
||||
|
||||
// Document files
|
||||
case "pdf":
|
||||
return "application/pdf"
|
||||
|
||||
default:
|
||||
return "application/octet-stream" // Default for unknown types
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -63,3 +64,13 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
info["cache_creation_ratio"] = cacheCreationRatio
|
||||
return info
|
||||
}
|
||||
|
||||
func GenerateMjOtherInfo(priceData helper.PerCallPriceData) map[string]interface{} {
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = priceData.ModelPrice
|
||||
other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio
|
||||
if priceData.GroupRatioInfo.HasSpecialRatio {
|
||||
other["user_group_ratio"] = priceData.GroupRatioInfo.GroupSpecialRatio
|
||||
}
|
||||
return other
|
||||
}
|
||||
|
||||
115
service/quota.go
115
service/quota.go
@@ -3,6 +3,8 @@ package service
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"one-api/common"
|
||||
constant2 "one-api/constant"
|
||||
"one-api/dto"
|
||||
@@ -10,7 +12,7 @@ import (
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -45,9 +47,9 @@ func calculateAudioQuota(info QuotaInfo) int {
|
||||
return int(quota.IntPart())
|
||||
}
|
||||
|
||||
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(info.ModelName))
|
||||
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(info.ModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(info.ModelName))
|
||||
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(info.ModelName))
|
||||
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(info.ModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(info.ModelName))
|
||||
|
||||
groupRatio := decimal.NewFromFloat(info.GroupRatio)
|
||||
modelRatio := decimal.NewFromFloat(info.ModelRatio)
|
||||
@@ -93,12 +95,21 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
textOutTokens := usage.OutputTokenDetails.TextTokens
|
||||
audioInputTokens := usage.InputTokenDetails.AudioTokens
|
||||
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
||||
groupRatio := setting.GetGroupRatio(relayInfo.Group)
|
||||
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
|
||||
if ok {
|
||||
groupRatio = userGroupRatio
|
||||
groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
|
||||
modelRatio, _ := ratio_setting.GetModelRatio(modelName)
|
||||
|
||||
autoGroup, exists := ctx.Get("auto_group")
|
||||
if exists {
|
||||
groupRatio = ratio_setting.GetGroupRatio(autoGroup.(string))
|
||||
log.Printf("final group ratio: %f", groupRatio)
|
||||
relayInfo.UsingGroup = autoGroup.(string)
|
||||
}
|
||||
|
||||
actualGroupRatio := groupRatio
|
||||
userGroupRatio, ok := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
|
||||
if ok {
|
||||
actualGroupRatio = userGroupRatio
|
||||
}
|
||||
modelRatio, _ := operation_setting.GetModelRatio(modelName)
|
||||
|
||||
quotaInfo := QuotaInfo{
|
||||
InputDetails: TokenDetails{
|
||||
@@ -112,7 +123,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
ModelName: modelName,
|
||||
UsePrice: relayInfo.UsePrice,
|
||||
ModelRatio: modelRatio,
|
||||
GroupRatio: groupRatio,
|
||||
GroupRatio: actualGroupRatio,
|
||||
}
|
||||
|
||||
quota := calculateAudioQuota(quotaInfo)
|
||||
@@ -134,8 +145,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
}
|
||||
|
||||
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
|
||||
usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
|
||||
modelPrice float64, usePrice bool, extraContent string) {
|
||||
usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
textInputTokens := usage.InputTokenDetails.TextTokens
|
||||
@@ -145,15 +155,15 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
audioOutTokens := usage.OutputTokenDetails.AudioTokens
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(modelName))
|
||||
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(modelName))
|
||||
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(modelName))
|
||||
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
|
||||
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
usePrice := priceData.UsePrice
|
||||
|
||||
actualGroupRatio := groupRatio
|
||||
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
|
||||
if ok {
|
||||
actualGroupRatio = userGroupRatio
|
||||
}
|
||||
quotaInfo := QuotaInfo{
|
||||
InputDetails: TokenDetails{
|
||||
TextTokens: textInputTokens,
|
||||
@@ -166,7 +176,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
ModelName: modelName,
|
||||
UsePrice: usePrice,
|
||||
ModelRatio: modelRatio,
|
||||
GroupRatio: actualGroupRatio,
|
||||
GroupRatio: groupRatio,
|
||||
}
|
||||
|
||||
quota := calculateAudioQuota(quotaInfo)
|
||||
@@ -198,9 +208,9 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio)
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
|
||||
}
|
||||
|
||||
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
@@ -214,15 +224,25 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := priceData.CompletionRatio
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
userGroupRatio := priceData.UserGroupRatio
|
||||
cacheRatio := priceData.CacheRatio
|
||||
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
||||
|
||||
cacheCreationRatio := priceData.CacheCreationRatio
|
||||
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
|
||||
if relayInfo.ChannelType == common.ChannelTypeOpenRouter {
|
||||
promptTokens -= cacheTokens
|
||||
if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 {
|
||||
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData)
|
||||
if promptTokens >= maybeCacheCreationTokens {
|
||||
cacheCreationTokens = maybeCacheCreationTokens
|
||||
}
|
||||
}
|
||||
promptTokens -= cacheCreationTokens
|
||||
}
|
||||
|
||||
calculateQuota := 0.0
|
||||
if !priceData.UsePrice {
|
||||
calculateQuota = float64(promptTokens)
|
||||
@@ -265,9 +285,30 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
}
|
||||
|
||||
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
|
||||
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, userGroupRatio)
|
||||
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, modelName,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
|
||||
}
|
||||
|
||||
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int {
|
||||
if priceData.CacheCreationRatio == 1 {
|
||||
return 0
|
||||
}
|
||||
quotaPrice := priceData.ModelRatio / common.QuotaPerUnit
|
||||
promptCacheCreatePrice := quotaPrice * priceData.CacheCreationRatio
|
||||
promptCacheReadPrice := quotaPrice * priceData.CacheRatio
|
||||
completionPrice := quotaPrice * priceData.CompletionRatio
|
||||
|
||||
cost := usage.Cost
|
||||
totalPromptTokens := float64(usage.PromptTokens)
|
||||
completionTokens := float64(usage.CompletionTokens)
|
||||
promptCacheReadTokens := float64(usage.PromptTokensDetails.CachedTokens)
|
||||
|
||||
return int(math.Round((cost -
|
||||
totalPromptTokens*quotaPrice +
|
||||
promptCacheReadTokens*(quotaPrice-promptCacheReadPrice) -
|
||||
completionTokens*completionPrice) /
|
||||
(promptCacheCreatePrice - quotaPrice)))
|
||||
}
|
||||
|
||||
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
@@ -281,21 +322,15 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
audioOutTokens := usage.CompletionTokenDetails.AudioTokens
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := decimal.NewFromFloat(operation_setting.GetCompletionRatio(relayInfo.OriginModelName))
|
||||
audioRatio := decimal.NewFromFloat(operation_setting.GetAudioRatio(relayInfo.OriginModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(operation_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
|
||||
completionRatio := decimal.NewFromFloat(ratio_setting.GetCompletionRatio(relayInfo.OriginModelName))
|
||||
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
|
||||
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
usePrice := priceData.UsePrice
|
||||
|
||||
actualGroupRatio := groupRatio
|
||||
userGroupRatio, ok := setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.Group)
|
||||
if ok {
|
||||
actualGroupRatio = userGroupRatio
|
||||
}
|
||||
|
||||
quotaInfo := QuotaInfo{
|
||||
InputDetails: TokenDetails{
|
||||
TextTokens: textInputTokens,
|
||||
@@ -308,7 +343,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
ModelName: relayInfo.OriginModelName,
|
||||
UsePrice: usePrice,
|
||||
ModelRatio: modelRatio,
|
||||
GroupRatio: actualGroupRatio,
|
||||
GroupRatio: groupRatio,
|
||||
}
|
||||
|
||||
quota := calculateAudioQuota(quotaInfo)
|
||||
@@ -348,9 +383,9 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, userGroupRatio)
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
|
||||
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.UsingGroup, other)
|
||||
}
|
||||
|
||||
func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
|
||||
|
||||
@@ -171,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
|
||||
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||
}
|
||||
}
|
||||
toolTokens, err := CountTokenInput(countStr, request.Model)
|
||||
toolTokens := CountTokenInput(countStr, request.Model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -194,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
|
||||
|
||||
// Count tokens in system message
|
||||
if request.System != "" {
|
||||
systemTokens, err := CountTokenInput(request.System, model)
|
||||
systemTokens := CountTokenInput(request.System, model)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -296,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
switch request.Type {
|
||||
case dto.RealtimeEventTypeSessionUpdate:
|
||||
if request.Session != nil {
|
||||
msgTokens, err := CountTextToken(request.Session.Instructions, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
msgTokens := CountTextToken(request.Session.Instructions, model)
|
||||
textToken += msgTokens
|
||||
}
|
||||
case dto.RealtimeEventResponseAudioDelta:
|
||||
@@ -311,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
audioToken += atk
|
||||
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
|
||||
// count text token
|
||||
tkm, err := CountTextToken(request.Delta, model)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("error counting text token: %v", err)
|
||||
}
|
||||
tkm := CountTextToken(request.Delta, model)
|
||||
textToken += tkm
|
||||
case dto.RealtimeEventInputAudioBufferAppend:
|
||||
// count audio token
|
||||
@@ -329,10 +323,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
case "message":
|
||||
for _, content := range request.Item.Content {
|
||||
if content.Type == "input_text" {
|
||||
tokens, err := CountTextToken(content.Text, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
tokens := CountTextToken(content.Text, model)
|
||||
textToken += tokens
|
||||
}
|
||||
}
|
||||
@@ -343,10 +334,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
if !info.IsFirstRequest {
|
||||
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
|
||||
for _, tool := range info.RealtimeTools {
|
||||
toolTokens, err := CountTokenInput(tool, model)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
toolTokens := CountTokenInput(tool, model)
|
||||
textToken += 8
|
||||
textToken += toolTokens
|
||||
}
|
||||
@@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
|
||||
return tokenNum, nil
|
||||
}
|
||||
|
||||
func CountTokenInput(input any, model string) (int, error) {
|
||||
func CountTokenInput(input any, model string) int {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return CountTextToken(v, model)
|
||||
@@ -432,13 +420,13 @@ func CountTokenInput(input any, model string) (int, error) {
|
||||
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
||||
tokens := 0
|
||||
for _, message := range messages {
|
||||
tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
|
||||
tkm := CountTokenInput(message.Delta.GetContentString(), model)
|
||||
tokens += tkm
|
||||
if message.Delta.ToolCalls != nil {
|
||||
for _, tool := range message.Delta.ToolCalls {
|
||||
tkm, _ := CountTokenInput(tool.Function.Name, model)
|
||||
tkm := CountTokenInput(tool.Function.Name, model)
|
||||
tokens += tkm
|
||||
tkm, _ = CountTokenInput(tool.Function.Arguments, model)
|
||||
tkm = CountTokenInput(tool.Function.Arguments, model)
|
||||
tokens += tkm
|
||||
}
|
||||
}
|
||||
@@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
|
||||
return tokens
|
||||
}
|
||||
|
||||
func CountTTSToken(text string, model string) (int, error) {
|
||||
func CountTTSToken(text string, model string) int {
|
||||
if strings.HasPrefix(model, "tts") {
|
||||
return utf8.RuneCountInString(text), nil
|
||||
return utf8.RuneCountInString(text)
|
||||
} else {
|
||||
return CountTextToken(text, model)
|
||||
}
|
||||
@@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
|
||||
//}
|
||||
|
||||
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||
func CountTextToken(text string, model string) (int, error) {
|
||||
var err error
|
||||
func CountTextToken(text string, model string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text), err
|
||||
return getTokenNum(tokenEncoder, text)
|
||||
}
|
||||
|
||||
@@ -16,13 +16,13 @@ import (
|
||||
// return 0, errors.New("unknown relay mode")
|
||||
//}
|
||||
|
||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
|
||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
ctkm, err := CountTextToken(responseText, modeName)
|
||||
ctkm := CountTextToken(responseText, modeName)
|
||||
usage.CompletionTokens = ctkm
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
return usage, err
|
||||
return usage
|
||||
}
|
||||
|
||||
func ValidUsage(usage *dto.Usage) bool {
|
||||
|
||||
31
setting/auto_group.go
Normal file
31
setting/auto_group.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package setting
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// AutoGroups is the list of group names treated as auto groups. It starts
// with the single entry "default" and can be replaced at runtime through
// UpdateAutoGroupsByJsonString.
var AutoGroups = []string{
	"default",
}

// DefaultUseAutoGroup is false unless changed elsewhere.
// NOTE(review): presumably toggles whether auto groups are used by default;
// it is not referenced in this file, so confirm against its callers.
var DefaultUseAutoGroup = false

// ContainsAutoGroup reports whether group is one of the entries in AutoGroups.
func ContainsAutoGroup(group string) bool {
	for _, autoGroup := range AutoGroups {
		if autoGroup == group {
			return true
		}
	}
	return false
}

// UpdateAutoGroupsByJsonString replaces AutoGroups with the JSON array of
// strings in jsonString.
//
// The input is decoded into a temporary slice first, so a malformed or
// partially valid document leaves the current AutoGroups untouched and only
// the decoding error is returned.
func UpdateAutoGroupsByJsonString(jsonString string) error {
	parsed := make([]string, 0)
	if err := json.Unmarshal([]byte(jsonString), &parsed); err != nil {
		return err
	}
	AutoGroups = parsed
	return nil
}

// AutoGroups2JsonString serializes AutoGroups to a JSON array. It returns
// "[]" if marshalling fails.
func AutoGroups2JsonString() string {
	jsonBytes, err := json.Marshal(AutoGroups)
	if err != nil {
		return "[]"
	}
	return string(jsonBytes)
}
|
||||
@@ -17,6 +17,8 @@ const (
|
||||
const (
|
||||
// Gemini Audio Input Price
|
||||
Gemini25FlashPreviewInputAudioPrice = 1.00
|
||||
Gemini25FlashProductionInputAudioPrice = 1.00 // for `gemini-2.5-flash`
|
||||
Gemini25FlashLitePreviewInputAudioPrice = 0.50
|
||||
Gemini25FlashNativeAudioInputAudioPrice = 3.00
|
||||
Gemini20FlashInputAudioPrice = 0.70
|
||||
)
|
||||
@@ -64,10 +66,14 @@ func GetFileSearchPricePerThousand() float64 {
|
||||
}
|
||||
|
||||
func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 {
|
||||
if strings.HasPrefix(modelName, "gemini-2.5-flash-preview") {
|
||||
return Gemini25FlashPreviewInputAudioPrice
|
||||
} else if strings.HasPrefix(modelName, "gemini-2.5-flash-preview-native-audio") {
|
||||
if strings.HasPrefix(modelName, "gemini-2.5-flash-preview-native-audio") {
|
||||
return Gemini25FlashNativeAudioInputAudioPrice
|
||||
} else if strings.HasPrefix(modelName, "gemini-2.5-flash-preview-lite") {
|
||||
return Gemini25FlashLitePreviewInputAudioPrice
|
||||
} else if strings.HasPrefix(modelName, "gemini-2.5-flash-preview") {
|
||||
return Gemini25FlashPreviewInputAudioPrice
|
||||
} else if strings.HasPrefix(modelName, "gemini-2.5-flash") {
|
||||
return Gemini25FlashProductionInputAudioPrice
|
||||
} else if strings.HasPrefix(modelName, "gemini-2.0-flash") {
|
||||
return Gemini20FlashInputAudioPrice
|
||||
}
|
||||
|
||||
@@ -1,8 +1,45 @@
|
||||
package setting
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
var PayAddress = ""
|
||||
var CustomCallbackAddress = ""
|
||||
var EpayId = ""
|
||||
var EpayKey = ""
|
||||
var Price = 7.3
|
||||
var MinTopUp = 1
|
||||
|
||||
// PayMethods is the list of configured payment methods. Each entry is a
// string map; the defaults carry the keys "name", "color" and "type", and
// ContainsPayMethod matches on "type".
var PayMethods = []map[string]string{
	{
		"name":  "支付宝",
		"color": "rgba(var(--semi-blue-5), 1)",
		"type":  "alipay",
	},
	{
		"name":  "微信",
		"color": "rgba(var(--semi-green-5), 1)",
		"type":  "wxpay",
	},
}

// UpdatePayMethodsByJsonString replaces PayMethods with the JSON array of
// string maps in jsonString.
//
// The input is decoded into a temporary slice first, so a malformed or
// partially valid document leaves the current PayMethods untouched and only
// the decoding error is returned.
func UpdatePayMethodsByJsonString(jsonString string) error {
	parsed := make([]map[string]string, 0)
	if err := json.Unmarshal([]byte(jsonString), &parsed); err != nil {
		return err
	}
	PayMethods = parsed
	return nil
}

// PayMethods2JsonString serializes PayMethods to a JSON array. It returns
// "[]" if marshalling fails.
func PayMethods2JsonString() string {
	jsonBytes, err := json.Marshal(PayMethods)
	if err != nil {
		return "[]"
	}
	return string(jsonBytes)
}

// ContainsPayMethod reports whether any entry in PayMethods has a "type"
// value equal to method.
func ContainsPayMethod(method string) bool {
	for _, payMethod := range PayMethods {
		if payMethod["type"] == method {
			return true
		}
	}
	return false
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package operation_setting
|
||||
package ratio_setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -85,7 +85,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error {
|
||||
cacheRatioMapMutex.Lock()
|
||||
defer cacheRatioMapMutex.Unlock()
|
||||
cacheRatioMap = make(map[string]float64)
|
||||
return json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
|
||||
err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
|
||||
if err == nil {
|
||||
InvalidateExposedDataCache()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// GetCacheRatio returns the cache ratio for a model
|
||||
@@ -106,3 +110,13 @@ func GetCreateCacheRatio(name string) (float64, bool) {
|
||||
}
|
||||
return ratio, true
|
||||
}
|
||||
|
||||
func GetCacheRatioCopy() map[string]float64 {
|
||||
cacheRatioMapMutex.RLock()
|
||||
defer cacheRatioMapMutex.RUnlock()
|
||||
copyMap := make(map[string]float64, len(cacheRatioMap))
|
||||
for k, v := range cacheRatioMap {
|
||||
copyMap[k] = v
|
||||
}
|
||||
return copyMap
|
||||
}
|
||||
17
setting/ratio_setting/expose_ratio.go
Normal file
17
setting/ratio_setting/expose_ratio.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package ratio_setting
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// exposeRatioEnabled holds the "expose ratio" switch. The zero value of
// atomic.Bool is false, so the switch starts disabled without any explicit
// initialization.
var exposeRatioEnabled atomic.Bool

// SetExposeRatioEnabled atomically sets the expose-ratio switch.
func SetExposeRatioEnabled(enabled bool) {
	exposeRatioEnabled.Store(enabled)
}

// IsExposeRatioEnabled atomically reads the expose-ratio switch.
func IsExposeRatioEnabled() bool {
	return exposeRatioEnabled.Load()
}
|
||||
55
setting/ratio_setting/exposed_cache.go
Normal file
55
setting/ratio_setting/exposed_cache.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package ratio_setting
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const exposedDataTTL = 30 * time.Second
|
||||
|
||||
type exposedCache struct {
|
||||
data gin.H
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
exposedData atomic.Value
|
||||
rebuildMu sync.Mutex
|
||||
)
|
||||
|
||||
func InvalidateExposedDataCache() {
|
||||
exposedData.Store((*exposedCache)(nil))
|
||||
}
|
||||
|
||||
func cloneGinH(src gin.H) gin.H {
|
||||
dst := make(gin.H, len(src))
|
||||
for k, v := range src {
|
||||
dst[k] = v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func GetExposedData() gin.H {
|
||||
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
|
||||
return cloneGinH(c.data)
|
||||
}
|
||||
rebuildMu.Lock()
|
||||
defer rebuildMu.Unlock()
|
||||
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
|
||||
return cloneGinH(c.data)
|
||||
}
|
||||
newData := gin.H{
|
||||
"model_ratio": GetModelRatioCopy(),
|
||||
"completion_ratio": GetCompletionRatioCopy(),
|
||||
"cache_ratio": GetCacheRatioCopy(),
|
||||
"model_price": GetModelPriceCopy(),
|
||||
}
|
||||
exposedData.Store(&exposedCache{
|
||||
data: newData,
|
||||
expiresAt: time.Now().Add(exposedDataTTL),
|
||||
})
|
||||
return cloneGinH(newData)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user