Compare commits

...

21 Commits

Author SHA1 Message Date
CalciumIon
5f082d72bb update dockerignore 2024-12-27 20:49:58 +08:00
CalciumIon
0fd0e5d309 fix: oauth bind 2024-12-27 18:32:11 +08:00
CalciumIon
d2297d2723 feat: update o1 default token encoder 2024-12-27 15:03:10 +08:00
CalciumIon
62ae46b552 feat: support azure stream_options 2024-12-26 22:51:06 +08:00
CalciumIon
0b1354ed51 update model ratio 2024-12-26 16:03:22 +08:00
Calcium-Ion
132c71390c Merge pull request #661 from tenacioustommy/fix-title-schema
fix delete title schema
2024-12-26 14:27:07 +08:00
Calcium-Ion
bb3deb7b93 Merge pull request #662 from xqx333/main
fix 重试过程多次获取图片
2024-12-26 14:26:50 +08:00
CalciumIon
f92d96e298 fix: update render function for quota display in Detail page 2024-12-26 14:25:44 +08:00
xqx333
c86762b656 Update relay-text.go
在上下文中存入promptTokens,避免重试过程重复计算
2024-12-26 02:00:04 +08:00
tenacious
3409d7a6b6 fix delete title schema 2024-12-26 00:24:45 +08:00
CalciumIon
bfba4866a5 fix: validate number input in renderQuotaNumberWithDigit and improve data handling in Detail page
- Added input validation to ensure that the `num` parameter in `renderQuotaNumberWithDigit` is a valid number, returning 0 for invalid inputs.
- Updated the `Detail` component to use `datum['rawQuota']` instead of `datum['Usage']` for rendering quota values, ensuring more accurate data representation.
- Enhanced data aggregation logic to handle cases where quota values may be missing or invalid, improving overall data integrity in charts and tables.
- Removed unnecessary time granularity calculations and streamlined the data processing for better performance.
2024-12-25 23:16:35 +08:00
CalciumIon
4fc1fe318e refactor: migrate group ratio and user usable groups logic to new setting package
- Replaced references to common.GroupRatio and common.UserUsableGroups with corresponding functions from the new setting package across multiple controllers and services.
- Introduced new setting functions for managing group ratios and user usable groups, enhancing code organization and maintainability.
- Updated related functions to ensure consistent behavior with the new setting package integration.
2024-12-25 19:31:12 +08:00
CalciumIon
b3576f24ef fix typo 2024-12-25 18:44:45 +08:00
CalciumIon
ed4d26fc9e fix: update MaxCompletionTokens for model prefix handling in buildTestRequest function 2024-12-25 17:55:20 +08:00
CalciumIon
ba56e2e8ca fix: correct user retrieval in GetPricing function 2024-12-25 14:29:52 +08:00
CalciumIon
7c20e6d047 fix: resolve pricing calculation issue (#659) 2024-12-25 14:26:43 +08:00
CalciumIon
72d6898eb5 feat: Implement batch tagging functionality for channels
- Added a new endpoint to batch set tags for multiple channels, allowing users to update tags efficiently.
- Introduced a new `BatchSetChannelTag` function in the controller to handle incoming requests and validate parameters.
- Updated the `BatchSetChannelTag` method in the model to manage database transactions and ensure data integrity during tag updates.
- Enhanced the ChannelsTable component in the frontend to support batch tag setting, including UI elements for user interaction.
- Updated localization files to include new translation keys related to batch operations and tag settings.
2024-12-25 14:19:00 +08:00
CalciumIon
f2c9388139 fix: update searchUsers function to include searchKeyword and searchGroup parameters 2024-12-25 13:44:55 +08:00
Calcium-Ion
aaf5cecefd Merge pull request #656 from Yan-Zero/main
fix: gemini function call
2024-12-25 13:38:34 +08:00
Yan
a8a2195ab1 Merge branch 'Calcium-Ion:main' into main 2024-12-24 20:46:16 +08:00
Yan
d40e6ec25d fix: gemini func call 2024-12-24 20:46:02 +08:00
39 changed files with 634 additions and 275 deletions

6
.dockerignore Normal file
View File

@@ -0,0 +1,6 @@
.github
.git
*.md
.vscode
.gitignore
Makefile

3
.gitignore vendored
View File

@@ -8,4 +8,5 @@ build
logs
web/dist
.env
one-api
one-api
.DS_Store

View File

@@ -356,7 +356,7 @@ func GetCompletionRatio(name string) float64 {
}
return 2
}
if strings.HasPrefix(name, "o1-") {
if strings.HasPrefix(name, "o1") {
return 4
}
if name == "chatgpt-4o-latest" {

View File

@@ -35,9 +35,7 @@ func StrToMap(str string) map[string]interface{} {
m := make(map[string]interface{})
err := json.Unmarshal([]byte(str), &m)
if err != nil {
return map[string]interface{}{
"result": str,
}
return nil
}
return m
}

View File

@@ -1,6 +1,9 @@
package constant
var (
FinishReasonStop = "stop"
FinishReasonToolCalls = "tool_calls"
FinishReasonStop = "stop"
FinishReasonToolCalls = "tool_calls"
FinishReasonLength = "length"
FinishReasonFunctionCall = "function_call"
FinishReasonContentFilter = "content_filter"
)

View File

@@ -152,8 +152,8 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
Model: "", // this will be set later
Stream: false,
}
if strings.HasPrefix(model, "o1-") {
testRequest.MaxCompletionTokens = 1
if strings.HasPrefix(model, "o1") {
testRequest.MaxCompletionTokens = 10
} else if strings.HasPrefix(model, "gemini-2.0-flash-thinking") {
testRequest.MaxTokens = 2
} else {

View File

@@ -419,7 +419,8 @@ func EditTagChannels(c *gin.Context) {
}
type ChannelBatch struct {
Ids []int `json:"ids"`
Ids []int `json:"ids"`
Tag *string `json:"tag"`
}
func DeleteChannelBatch(c *gin.Context) {
@@ -570,3 +571,29 @@ func FetchModels(c *gin.Context) {
"data": models,
})
}
func BatchSetChannelTag(c *gin.Context) {
channelBatch := ChannelBatch{}
err := c.ShouldBindJSON(&channelBatch)
if err != nil || len(channelBatch.Ids) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": len(channelBatch.Ids),
})
return
}

View File

@@ -3,13 +3,13 @@ package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"one-api/setting"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
for groupName, _ := range common.GroupRatio {
for groupName, _ := range setting.GetGroupRatioCopy() {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{
@@ -24,9 +24,9 @@ func GetUserGroups(c *gin.Context) {
userGroup := ""
userId := c.GetInt("id")
userGroup, _ = model.CacheGetUserGroup(userId)
for groupName, _ := range common.GroupRatio {
for groupName, _ := range setting.GetGroupRatioCopy() {
// UserUsableGroups contains the groups that the user can use
userUsableGroups := common.GetUserUsableGroups(userGroup)
userUsableGroups := setting.GetUserUsableGroups(userGroup)
if _, ok := userUsableGroups[groupName]; ok {
usableGroups[groupName] = userUsableGroups[groupName]
}

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"one-api/setting"
"strings"
"github.com/gin-gonic/gin"
@@ -83,7 +84,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "GroupRatio":
err = common.CheckGroupRatio(option.Value)
err = setting.CheckGroupRatio(option.Value)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -4,29 +4,28 @@ import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/model"
"one-api/setting"
)
func GetPricing(c *gin.Context) {
pricing := model.GetPricing()
userId, exists := c.Get("id")
usableGroup := map[string]string{}
groupRatio := common.GroupRatio
groupRatio := map[string]float64{}
for s, f := range setting.GetGroupRatioCopy() {
groupRatio[s] = f
}
var group string
if exists {
user, err := model.GetChannelById(userId.(int), false)
if err != nil {
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
user, err := model.GetUserById(userId.(int), false)
if err == nil {
group = user.Group
}
group = user.Group
}
usableGroup = common.GetUserUsableGroups(group)
usableGroup = setting.GetUserUsableGroups(group)
// check groupRatio contains usableGroup
for group := range common.GroupRatio {
for group := range setting.GetGroupRatioCopy() {
if _, ok := usableGroup[group]; !ok {
delete(groupRatio, group)
}

View File

@@ -17,6 +17,7 @@ import (
"one-api/relay/constant"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"strings"
)
@@ -83,7 +84,7 @@ func Playground(c *gin.Context) {
if group == "" {
group = userGroup
} else {
if !common.GroupInUserUsableGroups(group) && group != userGroup {
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
return
}

View File

@@ -10,6 +10,7 @@ import (
"one-api/model"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"strconv"
"strings"
"time"
@@ -43,12 +44,12 @@ func Distribute() func(c *gin.Context) {
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if _, ok := common.GroupRatio[tokenGroup]; !ok {
if !setting.ContainsGroupRatio(tokenGroup) {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}

View File

@@ -3,10 +3,11 @@ package model
import (
"errors"
"fmt"
"github.com/samber/lo"
"gorm.io/gorm"
"one-api/common"
"strings"
"github.com/samber/lo"
"gorm.io/gorm"
)
type Ability struct {
@@ -173,18 +174,67 @@ func (channel *Channel) DeleteAbilities() error {
// UpdateAbilities updates abilities of this channel.
// Make sure the channel is completed before calling this function.
func (channel *Channel) UpdateAbilities() error {
// A quick and dirty way to update abilities
func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
isNewTx := false
// 如果没有传入事务,创建新的事务
if tx == nil {
tx = DB.Begin()
if tx.Error != nil {
return tx.Error
}
isNewTx = true
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
}
// First delete all abilities of this channel
err := channel.DeleteAbilities()
err := tx.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
if err != nil {
if isNewTx {
tx.Rollback()
}
return err
}
// Then add new abilities
err = channel.AddAbilities()
if err != nil {
return err
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {
ability := Ability{
Group: group,
Model: model,
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
Weight: uint(channel.GetWeight()),
Tag: channel.Tag,
}
abilities = append(abilities, ability)
}
}
if len(abilities) > 0 {
for _, chunk := range lo.Chunk(abilities, 50) {
err = tx.Create(&chunk).Error
if err != nil {
if isNewTx {
tx.Rollback()
}
return err
}
}
}
// 如果是新创建的事务,需要提交
if isNewTx {
return tx.Commit().Error
}
return nil
}
@@ -246,7 +296,7 @@ func FixAbility() (int, error) {
return 0, err
}
for _, channel := range channels {
err := channel.UpdateAbilities()
err := channel.UpdateAbilities(nil)
if err != nil {
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
} else {

View File

@@ -257,7 +257,7 @@ func (channel *Channel) Update() error {
return err
}
DB.Model(channel).First(channel, "id = ?", channel.Id)
err = channel.UpdateAbilities()
err = channel.UpdateAbilities(nil)
return err
}
@@ -389,7 +389,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
channels, err := GetChannelsByTag(updatedTag, false)
if err == nil {
for _, channel := range channels {
err = channel.UpdateAbilities()
err = channel.UpdateAbilities(nil)
if err != nil {
common.SysError("failed to update abilities: " + err.Error())
}
@@ -509,3 +509,42 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) {
}
channel.Setting = string(settingBytes)
}
func GetChannelsByIds(ids []int) ([]*Channel, error) {
var channels []*Channel
err := DB.Where("id in (?)", ids).Find(&channels).Error
return channels, err
}
func BatchSetChannelTag(ids []int, tag *string) error {
// 开启事务
tx := DB.Begin()
if tx.Error != nil {
return tx.Error
}
// 更新标签
err := tx.Model(&Channel{}).Where("id in (?)", ids).Update("tag", tag).Error
if err != nil {
tx.Rollback()
return err
}
// update ability status
channels, err := GetChannelsByIds(ids)
if err != nil {
tx.Rollback()
return err
}
for _, channel := range channels {
err = channel.UpdateAbilities(tx)
if err != nil {
tx.Rollback()
return err
}
}
// 提交事务
return tx.Commit().Error
}

View File

@@ -87,8 +87,8 @@ func InitOptionMap() {
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
@@ -313,9 +313,9 @@ func updateOptionMap(key string, value string) (err error) {
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
err = setting.UpdateGroupRatioByJSONString(value)
case "UserUsableGroups":
err = common.UpdateUserUsableGroupsByJSONString(value)
err = setting.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":

View File

@@ -4,7 +4,7 @@ type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
Tools []GeminiChatTool `json:"tools,omitempty"`
SystemInstructions *GeminiChatContent `json:"system_instruction,omitempty"`
}
@@ -18,16 +18,39 @@ type FunctionCall struct {
Arguments any `json:"args"`
}
type GeminiFunctionResponseContent struct {
Name string `json:"name"`
Content any `json:"content"`
}
type FunctionResponse struct {
Name string `json:"name"`
Response any `json:"response"`
Name string `json:"name"`
Response GeminiFunctionResponseContent `json:"response"`
}
type GeminiPartExecutableCode struct {
Language string `json:"language,omitempty"`
Code string `json:"code,omitempty"`
}
type GeminiPartCodeExecutionResult struct {
Outcome string `json:"outcome,omitempty"`
Output string `json:"output,omitempty"`
}
type GeminiFileData struct {
MimeType string `json:"mimeType,omitempty"`
FileUri string `json:"fileUri,omitempty"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
FileData *GeminiFileData `json:"fileData,omitempty"`
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
}
type GeminiChatContent struct {
@@ -40,9 +63,11 @@ type GeminiChatSafetySettings struct {
Threshold string `json:"threshold"`
}
type GeminiChatTools struct {
GoogleSearch any `json:"googleSearch,omitempty"`
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
type GeminiChatTool struct {
GoogleSearch any `json:"googleSearch,omitempty"`
GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
CodeExecution any `json:"codeExecution,omitempty"`
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
@@ -54,11 +79,12 @@ type GeminiChatGenerationConfig struct {
StopSequences []string `json:"stopSequences,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
FinishReason *string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}

View File

@@ -18,6 +18,7 @@ import (
// Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
SafetySettings: []GeminiChatSafetySettings{
@@ -46,16 +47,24 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
Seed: int64(textRequest.Seed),
},
}
// openaiContent.FuncToToolCalls()
if textRequest.Tools != nil {
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
googleSearch := false
codeExecution := false
for _, tool := range textRequest.Tools {
if tool.Function.Name == "googleSearch" {
googleSearch = true
continue
}
if tool.Function.Name == "codeExecution" {
codeExecution = true
continue
}
if tool.Function.Parameters != nil {
params, ok := tool.Function.Parameters.(map[string]interface{})
if ok {
@@ -68,25 +77,32 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}
functions = append(functions, tool.Function)
}
if len(functions) > 0 {
geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: functions,
},
}
if codeExecution {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
CodeExecution: make(map[string]string),
})
}
if googleSearch {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTools{
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
GoogleSearch: make(map[string]string),
})
}
if len(functions) > 0 {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
FunctionDeclarations: functions,
})
}
// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
// json_data, _ := json.Marshal(geminiRequest.Tools)
// common.SysLog("tools_json: " + string(json_data))
} else if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTools{
geminiRequest.Tools = []GeminiChatTool{
{
FunctionDeclarations: textRequest.Functions,
},
}
}
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
@@ -96,20 +112,14 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}
}
tool_call_ids := make(map[string]string)
var system_content []string
//shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
if message.Role == "system" {
geminiRequest.SystemInstructions = &GeminiChatContent{
Parts: []GeminiPart{
{
Text: message.StringContent(),
},
},
}
system_content = append(system_content, message.StringContent())
continue
} else if message.Role == "tool" {
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role != "user" {
} else if message.Role == "tool" || message.Role == "function" {
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
Role: "user",
})
@@ -121,9 +131,16 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
name = val
}
content := common.StrToMap(message.StringContent())
functionResp := &FunctionResponse{
Name: name,
Response: common.StrToMap(message.StringContent()),
Name: name,
Response: GeminiFunctionResponseContent{
Name: name,
Content: content,
},
}
if content == nil {
functionResp.Response.Content = message.StringContent()
}
*parts = append(*parts, GeminiPart{
FunctionResponse: functionResp,
@@ -134,57 +151,65 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
content := GeminiChatContent{
Role: message.Role,
}
isToolCall := false
// isToolCall := false
if message.ToolCalls != nil {
message.Role = "model"
isToolCall = true
// message.Role = "model"
// isToolCall = true
for _, call := range message.ParseToolCalls() {
args := map[string]interface{}{}
if call.Function.Arguments != "" {
if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
}
}
toolCall := GeminiPart{
FunctionCall: &FunctionCall{
FunctionName: call.Function.Name,
Arguments: call.Function.Parameters,
Arguments: args,
},
}
parts = append(parts, toolCall)
tool_call_ids[call.ID] = call.Function.Name
}
}
if !isToolCall {
openaiContent := message.ParseContent()
imageNum := 0
for _, part := range openaiContent {
if part.Type == dto.ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == dto.ContentTypeImageURL {
imageNum += 1
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
}
// 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// 是url获取图片的类型和base64编码的数据
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
} else {
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "image/" + format,
Data: base64String,
},
})
openaiContent := message.ParseContent()
imageNum := 0
for _, part := range openaiContent {
if part.Type == dto.ContentTypeText {
if part.Text == "" {
continue
}
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == dto.ContentTypeImageURL {
imageNum += 1
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
}
// 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
// 是url获取图片的类型和base64编码的数据
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
} else {
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "image/" + format,
Data: base64String,
},
})
}
}
}
@@ -197,6 +222,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
}
if len(system_content) > 0 {
geminiRequest.SystemInstructions = &GeminiChatContent{
Parts: []GeminiPart{
{
Text: strings.Join(system_content, "\n"),
},
},
}
}
return &geminiRequest, nil
}
@@ -209,12 +245,12 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac
if !ok || len(v) == 0 {
return schema
}
// 删除所有的title字段
delete(v, "title")
// 如果type不为object和array则直接返回
if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
return schema
}
delete(v, "title")
switch v["type"] {
case "object":
delete(v, "additionalProperties")
@@ -240,15 +276,15 @@ func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interfac
return v
}
func (g *GeminiChatResponse) GetResponseText() string {
if g == nil {
return ""
}
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
return g.Candidates[0].Content.Parts[0].Text
}
return ""
}
// func (g *GeminiChatResponse) GetResponseText() string {
// if g == nil {
// return ""
// }
// if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
// return g.Candidates[0].Content.Parts[0].Text
// }
// return ""
// }
func getToolCall(item *GeminiPart) *dto.ToolCall {
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
@@ -298,11 +334,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
content, _ := json.Marshal("")
for i, candidate := range response.Candidates {
// jsonData, _ := json.MarshalIndent(candidate, "", " ")
// common.SysLog(fmt.Sprintf("candidate: %v", string(jsonData)))
is_tool_call := false
for _, candidate := range response.Candidates {
choice := dto.OpenAITextResponseChoice{
Index: i,
Index: int(candidate.Index),
Message: dto.Message{
Role: "assistant",
Content: content,
@@ -319,48 +354,107 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
tool_calls = append(tool_calls, *call)
}
} else {
texts = append(texts, part.Text)
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
} else if part.CodeExecutionResult != nil {
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
} else {
// 过滤掉空行
if part.Text != "\n" {
texts = append(texts, part.Text)
}
}
}
}
if len(tool_calls) > 0 {
choice.Message.SetToolCalls(tool_calls)
is_tool_call = true
}
// 过滤掉空行
choice.Message.SetStringContent(strings.Join(texts, "\n"))
choice.Message.SetToolCalls(tool_calls)
}
if candidate.FinishReason != nil {
switch *candidate.FinishReason {
case "STOP":
choice.FinishReason = constant.FinishReasonStop
case "MAX_TOKENS":
choice.FinishReason = constant.FinishReasonLength
default:
choice.FinishReason = constant.FinishReasonContentFilter
}
}
if is_tool_call {
choice.FinishReason = constant.FinishReasonToolCalls
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
//choice.Delta.SetContentString(geminiResponse.GetResponseText())
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
is_stop := false
for _, candidate := range geminiResponse.Candidates {
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
is_stop = true
candidate.FinishReason = nil
}
choice := dto.ChatCompletionsStreamResponseChoice{
Index: int(candidate.Index),
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
Role: "assistant",
},
}
var texts []string
var tool_calls []dto.ToolCall
for _, part := range geminiResponse.Candidates[0].Content.Parts {
if part.FunctionCall != nil {
if call := getToolCall(&part); call != nil {
tool_calls = append(tool_calls, *call)
}
} else {
texts = append(texts, part.Text)
isTools := false
if candidate.FinishReason != nil {
// p := GeminiConvertFinishReason(*candidate.FinishReason)
switch *candidate.FinishReason {
case "STOP":
choice.FinishReason = &constant.FinishReasonStop
case "MAX_TOKENS":
choice.FinishReason = &constant.FinishReasonLength
default:
choice.FinishReason = &constant.FinishReasonContentFilter
}
}
if len(texts) > 0 {
choice.Delta.SetContentString(strings.Join(texts, "\n"))
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
isTools = true
if call := getToolCall(&part); call != nil {
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
}
} else {
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
} else if part.CodeExecutionResult != nil {
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
} else {
if part.Text != "\n" {
texts = append(texts, part.Text)
}
}
}
}
if len(tool_calls) > 0 {
choice.Delta.ToolCalls = tool_calls
choice.Delta.SetContentString(strings.Join(texts, "\n"))
if isTools {
choice.FinishReason = &constant.FinishReasonToolCalls
}
choices = append(choices, choice)
}
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
return &response
response.Choices = choices
return &response, is_stop
}
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseText := ""
// responseText := ""
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
@@ -384,14 +478,11 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
continue
}
response := streamResponseGeminiChat2OpenAI(&geminiResponse)
if response == nil {
continue
}
response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse)
response.Id = id
response.Created = createAt
response.Model = info.UpstreamModelName
responseText += response.Choices[0].Delta.GetContentString()
// responseText += response.Choices[0].Delta.GetContentString()
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
@@ -400,12 +491,17 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
if err != nil {
common.LogError(c, err.Error())
}
if is_stop {
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
service.ObjectData(c, response)
}
}
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
service.ObjectData(c, response)
var response *dto.ChatCompletionsStreamResponse
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
if info.ShouldIncludeUsage {
response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)

View File

@@ -106,7 +106,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil {
return nil, errors.New("request is nil")
}
if info.ChannelType != common.ChannelTypeOpenAI {
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
request.StreamOptions = nil
}
if strings.HasPrefix(request.Model, "o1") {

View File

@@ -109,7 +109,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
}
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare {
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure {
info.SupportStreamOptions = true
}
return info

View File

@@ -74,7 +74,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)

View File

@@ -99,7 +99,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
modelPrice = 0.0025 * modelRatio
}
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
sizeRatio := 1.0

View File

@@ -168,7 +168,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
modelPrice = defaultPrice
}
}
groupRatio := common.GetGroupRatio(group)
groupRatio := setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
@@ -474,7 +474,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
modelPrice = defaultPrice
}
}
groupRatio := common.GetGroupRatio(group)
groupRatio := setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {

View File

@@ -94,7 +94,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
@@ -108,10 +108,17 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
}
promptTokens, err := getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
// 获取 promptTokens如果上下文中已经存在则直接使用
var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists {
promptTokens = value.(int)
} else {
promptTokens, err = getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
}
c.Set("prompt_tokens", promptTokens)
}
if !getModelPriceSuccess {

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting"
)
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
@@ -57,7 +58,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
relayInfo.UpstreamModelName = rerankRequest.Model
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64

View File

@@ -16,6 +16,7 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
)
/*
@@ -48,7 +49,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
// 预扣
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
if err != nil {

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting"
)
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
@@ -57,7 +58,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
}
//relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64

View File

@@ -28,10 +28,10 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.TelegramBind)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
userRoute := apiRouter.Group("/user")
{
@@ -99,7 +99,7 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.POST("/fix", controller.FixChannelsAbilities)
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
channelRoute.POST("/fetch_models", controller.FetchModels)
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
}
tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth())

View File

@@ -9,6 +9,7 @@ import (
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
"one-api/setting"
"strings"
"time"
)
@@ -36,7 +37,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
completionRatio := common.GetCompletionRatio(modelName)
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
modelRatio := common.GetModelRatio(modelName)
ratio := groupRatio * modelRatio

View File

@@ -19,42 +19,40 @@ import (
// tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken
var cl200kTokenEncoder *tiktoken.Tiktoken
var o200kTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() {
common.SysLog("initializing token encoders")
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
cl100TokenEncoder, err := tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
defaultTokenEncoder = cl100TokenEncoder
o200kTokenEncoder, err = tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
}
for model, _ := range common.GetDefaultModelRatioMap() {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
tokenEncoderMap[model] = cl100TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
if strings.HasPrefix(model, "gpt-4o") {
tokenEncoderMap[model] = cl200kTokenEncoder
tokenEncoderMap[model] = o200kTokenEncoder
} else {
tokenEncoderMap[model] = gpt4TokenEncoder
tokenEncoderMap[model] = defaultTokenEncoder
}
} else if strings.HasPrefix(model, "o1") {
tokenEncoderMap[model] = o200kTokenEncoder
} else {
tokenEncoderMap[model] = nil
tokenEncoderMap[model] = defaultTokenEncoder
}
}
common.SysLog("token encoders initialized")
}
func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") {
return cl200kTokenEncoder
if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") || strings.HasPrefix(model, "o1") {
return o200kTokenEncoder
}
return defaultTokenEncoder
}
@@ -92,11 +90,11 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
}
// TODO: 非流模式下不计算图片token数量
if !constant.GetMediaTokenNotStream && !stream {
return 1000, nil
return 256, nil
}
// 是否统计图片token
if !constant.GetMediaToken {
return 1000, nil
return 256, nil
}
// 同步One API的图片计费逻辑
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {

View File

@@ -1,33 +1,47 @@
package common
package setting
import (
"encoding/json"
"errors"
"one-api/common"
)
var GroupRatio = map[string]float64{
var groupRatio = map[string]float64{
"default": 1,
"vip": 1,
"svip": 1,
}
func GetGroupRatioCopy() map[string]float64 {
groupRatioCopy := make(map[string]float64)
for k, v := range groupRatio {
groupRatioCopy[k] = v
}
return groupRatioCopy
}
func ContainsGroupRatio(name string) bool {
_, ok := groupRatio[name]
return ok
}
func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio)
jsonBytes, err := json.Marshal(groupRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
common.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateGroupRatioByJSONString(jsonStr string) error {
GroupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &GroupRatio)
groupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &groupRatio)
}
func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name]
ratio, ok := groupRatio[name]
if !ok {
SysError("group ratio not found: " + name)
common.SysError("group ratio not found: " + name)
return 1
}
return ratio

View File

@@ -1,7 +1,8 @@
package common
package setting
import (
"encoding/json"
"one-api/common"
)
var UserUsableGroups = map[string]string{
@@ -12,7 +13,7 @@ var UserUsableGroups = map[string]string{
func UserUsableGroups2JSONString() string {
jsonBytes, err := json.Marshal(UserUsableGroups)
if err != nil {
SysError("error marshalling user groups: " + err.Error())
common.SysError("error marshalling user groups: " + err.Error())
}
return string(jsonBytes)
}

View File

@@ -162,9 +162,15 @@ const ChannelsTable = () => {
return (
<div>
<Space spacing={2}>
{text?.split(',').map((item, index) => {
return renderGroup(item);
})}
{text?.split(',')
.sort((a, b) => {
if (a === 'default') return -1;
if (b === 'default') return 1;
return a.localeCompare(b);
})
.map((item, index) => {
return renderGroup(item);
})}
</Space>
</div>
);
@@ -507,6 +513,8 @@ const ChannelsTable = () => {
const [selectedChannels, setSelectedChannels] = useState([]);
const [showEditPriority, setShowEditPriority] = useState(false);
const [enableTagMode, setEnableTagMode] = useState(false);
const [showBatchSetTag, setShowBatchSetTag] = useState(false);
const [batchSetTagValue, setBatchSetTagValue] = useState('');
const removeRecord = (record) => {
@@ -968,6 +976,29 @@ const ChannelsTable = () => {
}
};
const batchSetChannelTag = async () => {
if (selectedChannels.length === 0) {
showError(t('请先选择要设置标签的渠道!'));
return;
}
if (batchSetTagValue === '') {
showError(t('标签不能为空!'));
return;
}
let ids = selectedChannels.map(channel => channel.id);
const res = await API.post('/api/channel/batch/tag', {
ids: ids,
tag: batchSetTagValue === '' ? null : batchSetTagValue
});
if (res.data.success) {
showSuccess(t('已为 ${count} 个渠道设置标签!').replace('${count}', res.data.data));
await refresh();
setShowBatchSetTag(false);
} else {
showError(res.data.message);
}
};
return (
<>
<EditTagModal
@@ -1115,11 +1146,11 @@ const ChannelsTable = () => {
</div>
<div style={{ marginTop: 20 }}>
<Space>
<Typography.Text strong>{t('开启批量删除')}</Typography.Text>
<Typography.Text strong>{t('开启批量操作')}</Typography.Text>
<Switch
label={t('开启批量删除')}
label={t('开启批量操作')}
uncheckedText={t('关')}
aria-label={t('是否开启批量删除')}
aria-label={t('是否开启批量操作')}
onChange={(v) => {
setEnableBatchDelete(v);
}}
@@ -1167,7 +1198,17 @@ const ChannelsTable = () => {
loadChannels(0, pageSize, idSort, v);
}}
/>
<Button
disabled={!enableBatchDelete}
theme="light"
type="primary"
style={{ marginRight: 8 }}
onClick={() => setShowBatchSetTag(true)}
>
{t('批量设置标签')}
</Button>
</Space>
</div>
@@ -1201,6 +1242,23 @@ const ChannelsTable = () => {
: null
}
/>
<Modal
title={t('批量设置标签')}
visible={showBatchSetTag}
onOk={batchSetChannelTag}
onCancel={() => setShowBatchSetTag(false)}
maskClosable={false}
centered={true}
>
<div style={{ marginBottom: 20 }}>
<Typography.Text>{t('请输入要设置的标签名称')}</Typography.Text>
</div>
<Input
placeholder={t('请输入标签名称')}
value={batchSetTagValue}
onChange={(v) => setBatchSetTagValue(v)}
/>
</Modal>
</>
);
};

View File

@@ -406,7 +406,7 @@ const UsersTable = () => {
if (searchKeyword === '') {
await loadUsers(activePage - 1);
} else {
await searchUsers();
await searchUsers(searchKeyword, searchGroup);
}
};

View File

@@ -59,6 +59,9 @@ export function renderNumber(num) {
}
export function renderQuotaNumberWithDigit(num, digits = 2) {
if (typeof num !== 'number' || isNaN(num)) {
return 0;
}
let displayInCurrency = localStorage.getItem('display_in_currency');
num = num.toFixed(digits);
if (displayInCurrency) {

View File

@@ -180,6 +180,9 @@ export function timestamp2string1(timestamp, dataExportDefaultTime = 'hour') {
let month = (date.getMonth() + 1).toString();
let day = date.getDate().toString();
let hour = date.getHours().toString();
if (day === '24') {
console.log("timestamp", timestamp);
}
if (month.length === 1) {
month = '0' + month;
}

View File

@@ -546,8 +546,8 @@
"是否用ID排序": "Whether to sort by ID",
"确定?": "Sure?",
"确定是否要删除禁用通道?": "Are you sure you want to delete the disabled channel?",
"开启批量删除": "Enable batch selection",
"是否开启批量删除": "Whether to enable batch selection",
"开启批量操作": "Enable batch selection",
"是否开启批量操作": "Whether to enable batch selection",
"确定是否要删除所选通道?": "Are you sure you want to delete the selected channels?",
"确定是否要修复数据库一致性?": "Are you sure you want to repair database consistency?",
"进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "When performing this operation, it may cause channel access errors. Please only use it when there is a problem with the database.",

View File

@@ -548,8 +548,8 @@
"是否用ID排序": "Whether to sort by ID",
"确定?": "Sure?",
"确定是否要删除禁用通道?": "Are you sure you want to delete the disabled channel?",
"开启批量删除": "Enable batch selection",
"是否开启批量删除": "Whether to enable batch selection",
"开启批量操作": "Enable batch selection",
"是否开启批量操作": "Whether to enable batch selection",
"确定是否要删除所选通道?": "Are you sure you want to delete the selected channels?",
"确定是否要修复数据库一致性?": "Are you sure you want to repair database consistency?",
"进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "When performing this operation, it may cause channel access errors. Please only use it when there is a problem with the database.",
@@ -1237,5 +1237,8 @@
"更多": "Expand more",
"个模型": "models",
"可用模型": "Available models",
"时间范围": "Time range"
"时间范围": "Time range",
"批量设置标签": "Batch set tag",
"请输入要设置的标签名称": "Please enter the tag name to be set",
"请输入标签名称": "Please enter the tag name"
}

View File

@@ -143,8 +143,7 @@ const Detail = (props) => {
content: [
{
key: (datum) => datum['Model'],
value: (datum) =>
renderQuotaNumberWithDigit(parseFloat(datum['Usage']), 4),
value: (datum) => renderQuota(datum['rawQuota'] || 0, 4),
},
],
},
@@ -152,22 +151,28 @@ const Detail = (props) => {
content: [
{
key: (datum) => datum['Model'],
value: (datum) => datum['Usage'],
value: (datum) => datum['rawQuota'] || 0,
},
],
updateContent: (array) => {
array.sort((a, b) => b.value - a.value);
let sum = 0;
for (let i = 0; i < array.length; i++) {
sum += parseFloat(array[i].value);
array[i].value = renderQuotaNumberWithDigit(
parseFloat(array[i].value),
4,
);
if (array[i].key == "其他") {
continue;
}
let value = parseFloat(array[i].value);
if (isNaN(value)) {
value = 0;
}
if (array[i].datum && array[i].datum.TimeSum) {
sum = array[i].datum.TimeSum;
}
array[i].value = renderQuota(value, 4);
}
array.unshift({
key: t('总计'),
value: renderQuotaNumberWithDigit(sum, 4),
value: renderQuota(sum, 4),
});
return array;
},
@@ -212,19 +217,8 @@ const Detail = (props) => {
created_at: now.getTime() / 1000,
});
}
// 根据dataExportDefaultTime重制时间粒度
let timeGranularity = 3600;
if (dataExportDefaultTime === 'day') {
timeGranularity = 86400;
} else if (dataExportDefaultTime === 'week') {
timeGranularity = 604800;
}
// sort created_at
data.sort((a, b) => a.created_at - b.created_at);
data.forEach((item) => {
item['created_at'] =
Math.floor(item['created_at'] / timeGranularity) * timeGranularity;
});
updateChartData(data);
} else {
showError(message);
@@ -250,14 +244,14 @@ const Detail = (props) => {
let uniqueModels = new Set();
let totalTokens = 0;
// 收集所有唯一的模型名称和时间点
let uniqueTimes = new Set();
// 收集所有唯一的模型名称
data.forEach(item => {
uniqueModels.add(item.model_name);
uniqueTimes.add(timestamp2string1(item.created_at, dataExportDefaultTime));
totalTokens += item.token_used;
totalQuota += item.quota;
totalTimes += item.count;
});
// 处理颜色映射
const newModelColors = {};
Array.from(uniqueModels).forEach((modelName) => {
@@ -267,56 +261,82 @@ const Detail = (props) => {
});
setModelColors(newModelColors);
// 处理饼图数据
for (let item of data) {
totalQuota += item.quota;
totalTimes += item.count;
let pieItem = newPieData.find((it) => it.type === item.model_name);
if (pieItem) {
pieItem.value += item.count;
} else {
newPieData.push({
type: item.model_name,
value: item.count,
// 按时间和模型聚合数据
let aggregatedData = new Map();
data.forEach(item => {
const timeKey = timestamp2string1(item.created_at, dataExportDefaultTime);
const modelKey = item.model_name;
const key = `${timeKey}-${modelKey}`;
if (!aggregatedData.has(key)) {
aggregatedData.set(key, {
time: timeKey,
model: modelKey,
quota: 0,
count: 0
});
}
const existing = aggregatedData.get(key);
existing.quota += item.quota;
existing.count += item.count;
});
// 处理饼图数据
let modelTotals = new Map();
for (let [_, value] of aggregatedData) {
if (!modelTotals.has(value.model)) {
modelTotals.set(value.model, 0);
}
modelTotals.set(value.model, modelTotals.get(value.model) + value.count);
}
// 处理柱状图数据
let timePoints = Array.from(uniqueTimes);
newPieData = Array.from(modelTotals).map(([model, count]) => ({
type: model,
value: count
}));
// 生成时间点序列
let timePoints = Array.from(new Set([...aggregatedData.values()].map(d => d.time)));
if (timePoints.length < 7) {
// 根据时间粒度生成合适的时间点
const generateTimePoints = () => {
let lastTime = Math.max(...data.map(item => item.created_at));
let points = [];
let interval = dataExportDefaultTime === 'hour' ? 3600
const lastTime = Math.max(...data.map(item => item.created_at));
const interval = dataExportDefaultTime === 'hour' ? 3600
: dataExportDefaultTime === 'day' ? 86400
: 604800;
for (let i = 0; i < 7; i++) {
points.push(timestamp2string1(lastTime - (i * interval), dataExportDefaultTime));
}
return points.reverse();
};
timePoints = generateTimePoints();
timePoints = Array.from({length: 7}, (_, i) =>
timestamp2string1(lastTime - (6-i) * interval, dataExportDefaultTime)
);
}
// 为每个时间点和模型生成数据
// 生成柱状图数据
timePoints.forEach(time => {
Array.from(uniqueModels).forEach(model => {
let existingData = data.find(item =>
timestamp2string1(item.created_at, dataExportDefaultTime) === time &&
item.model_name === model
);
newLineData.push({
// 为每个时间点收集所有模型的数据
let timeData = Array.from(uniqueModels).map(model => {
const key = `${time}-${model}`;
const aggregated = aggregatedData.get(key);
return {
Time: time,
Model: model,
Usage: existingData ? parseFloat(getQuotaWithUnit(existingData.quota)) : 0
});
rawQuota: aggregated?.quota || 0,
Usage: aggregated?.quota ? getQuotaWithUnit(aggregated.quota, 4) : 0
};
});
// 计算该时间点的总计
const timeSum = timeData.reduce((sum, item) => sum + item.rawQuota, 0);
// 按照 rawQuota 从大到小排序
timeData.sort((a, b) => b.rawQuota - a.rawQuota);
// 为每个数据点添加该时间的总计
timeData = timeData.map(item => ({
...item,
TimeSum: timeSum
}));
// 将排序后的数据添加到 newLineData
newLineData.push(...timeData);
});
// 排序

View File

@@ -390,7 +390,7 @@ const EditToken = (props) => {
setUnlimitedQuota();
}}
>
{unlimited_quota ? t('取消<EFBFBD><EFBFBD><EFBFBD>限额度') : t('设为无限额度')}
{unlimited_quota ? t('取消限额度') : t('设为无限额度')}
</Button>
</div>
<Divider />