Compare commits

...

61 Commits

Author SHA1 Message Date
CalciumIon
52c023a1dd fix #663 2024-12-27 21:59:05 +08:00
CalciumIon
1cef91a741 fix: prevent duplicate models in user group retrieval 2024-12-27 21:25:44 +08:00
CalciumIon
77861e6440 refactor: improve user group handling and add GetUserUsableGroups function
- Introduced a new function `GetUserUsableGroupsCopy` to return a copy of user usable groups.
- Updated `GetUserUsableGroups` to utilize the new function for better encapsulation.
- Changed variable names from `UserUsableGroups` to `userUsableGroups` for consistency.
- Enhanced `GetUserUsableGroups` logic to ensure it returns a copy of the groups, preventing unintended modifications.
2024-12-27 21:19:22 +08:00
CalciumIon
5f082d72bb update dockerignore 2024-12-27 20:49:58 +08:00
CalciumIon
0fd0e5d309 fix: oauth bind 2024-12-27 18:32:11 +08:00
CalciumIon
d2297d2723 feat: update o1 default token encoder 2024-12-27 15:03:10 +08:00
CalciumIon
62ae46b552 feat: support azure stream_options 2024-12-26 22:51:06 +08:00
CalciumIon
0b1354ed51 update model ratio 2024-12-26 16:03:22 +08:00
Calcium-Ion
132c71390c Merge pull request #661 from tenacioustommy/fix-title-schema
fix delete title schema
2024-12-26 14:27:07 +08:00
Calcium-Ion
bb3deb7b93 Merge pull request #662 from xqx333/main
fix 重试过程多次获取图片
2024-12-26 14:26:50 +08:00
CalciumIon
f92d96e298 fix: update render function for quota display in Detail page 2024-12-26 14:25:44 +08:00
xqx333
c86762b656 Update relay-text.go
在上下文中存入promptTokens,避免重试过程重复计算
2024-12-26 02:00:04 +08:00
tenacious
3409d7a6b6 fix delete title schema 2024-12-26 00:24:45 +08:00
CalciumIon
bfba4866a5 fix: validate number input in renderQuotaNumberWithDigit and improve data handling in Detail page
- Added input validation to ensure that the `num` parameter in `renderQuotaNumberWithDigit` is a valid number, returning 0 for invalid inputs.
- Updated the `Detail` component to use `datum['rawQuota']` instead of `datum['Usage']` for rendering quota values, ensuring more accurate data representation.
- Enhanced data aggregation logic to handle cases where quota values may be missing or invalid, improving overall data integrity in charts and tables.
- Removed unnecessary time granularity calculations and streamlined the data processing for better performance.
2024-12-25 23:16:35 +08:00
CalciumIon
4fc1fe318e refactor: migrate group ratio and user usable groups logic to new setting package
- Replaced references to common.GroupRatio and common.UserUsableGroups with corresponding functions from the new setting package across multiple controllers and services.
- Introduced new setting functions for managing group ratios and user usable groups, enhancing code organization and maintainability.
- Updated related functions to ensure consistent behavior with the new setting package integration.
2024-12-25 19:31:12 +08:00
CalciumIon
b3576f24ef fix typo 2024-12-25 18:44:45 +08:00
CalciumIon
ed4d26fc9e fix: update MaxCompletionTokens for model prefix handling in buildTestRequest function 2024-12-25 17:55:20 +08:00
CalciumIon
ba56e2e8ca fix: correct user retrieval in GetPricing function 2024-12-25 14:29:52 +08:00
CalciumIon
7c20e6d047 fix: resolve pricing calculation issue (#659) 2024-12-25 14:26:43 +08:00
CalciumIon
72d6898eb5 feat: Implement batch tagging functionality for channels
- Added a new endpoint to batch set tags for multiple channels, allowing users to update tags efficiently.
- Introduced a new `BatchSetChannelTag` function in the controller to handle incoming requests and validate parameters.
- Updated the `BatchSetChannelTag` method in the model to manage database transactions and ensure data integrity during tag updates.
- Enhanced the ChannelsTable component in the frontend to support batch tag setting, including UI elements for user interaction.
- Updated localization files to include new translation keys related to batch operations and tag settings.
2024-12-25 14:19:00 +08:00
CalciumIon
f2c9388139 fix: update searchUsers function to include searchKeyword and searchGroup parameters 2024-12-25 13:44:55 +08:00
Calcium-Ion
aaf5cecefd Merge pull request #656 from Yan-Zero/main
fix: gemini function call
2024-12-25 13:38:34 +08:00
CalciumIon
fe2165ace6 fix: #657 2024-12-24 22:30:05 +08:00
CalciumIon
3003d12a20 fix: get upstream models 2024-12-24 20:48:21 +08:00
Yan
a8a2195ab1 Merge branch 'Calcium-Ion:main' into main 2024-12-24 20:46:16 +08:00
Yan
d40e6ec25d fix: gemini func call 2024-12-24 20:46:02 +08:00
CalciumIon
8129aa76f9 feat: Enhance pricing functionality with user group support
- Updated the GetPricing function in the backend to include user group information, allowing for dynamic adjustment of group ratios based on the user's group.
- Implemented logic to filter group ratios based on the user's usable groups, improving the accuracy of pricing data returned.
- Modified the ModelPricing component to utilize the new usable group data, ensuring only relevant groups are displayed in the UI.
- Enhanced state management in the frontend to accommodate the new usable group information, improving user experience and data consistency.
2024-12-24 19:23:29 +08:00
CalciumIon
fb8595da18 feat: Update localization and enhance token editing functionality
- Added new translation keys for English localization in `en.json`, including "Token group, default is the your's group" and "IP whitelist (do not overly trust this function)".
- Refactored `EditToken.js` to utilize the `useTranslation` hook for improved internationalization, ensuring all user-facing strings are translatable.
- Updated error and success messages to use translation functions, enhancing user experience for non-English speakers.
- Improved UI elements to support localization, including labels, placeholders, and button texts, ensuring consistency across the token editing interface.
2024-12-24 18:40:18 +08:00
CalciumIon
93cda60d44 feat: Add FetchModels endpoint and refactor FetchUpstreamModels
- Introduced a new `FetchModels` endpoint to retrieve model IDs from a specified base URL and API key, enhancing flexibility for different channel types.
- Refactored `FetchUpstreamModels` to simplify base URL handling and improve error messages during response parsing.
- Updated API routes to include the new endpoint and adjusted the frontend to utilize the new fetch mechanism for model lists.
- Removed outdated checks for channel type in the frontend, streamlining the model fetching process.
2024-12-24 18:02:08 +08:00
CalciumIon
2ec5eafbce feat: Enhance LogsTable component with mobile support and date handling improvements
- Added mobile-specific date pickers for start and end timestamps in the LogsTable component, improving user experience on mobile devices.
- Updated the input handling for date values to ensure valid date formats are maintained.
- Introduced a new translation key for "时间范围" (Time range) in the English locale file to support localization efforts.
2024-12-24 15:44:11 +08:00
CalciumIon
be0c240e97 Merge remote-tracking branch 'origin/main' 2024-12-24 14:48:43 +08:00
CalciumIon
7180e6f114 feat: Enhance logging functionality with group support
- Added a new 'group' parameter to various logging functions, including RecordConsumeLog, GetAllLogs, and GetUserLogs, to allow for more granular log tracking.
- Updated the logs table component to display group information, improving the visibility of log data.
- Refactored related functions to accommodate the new group parameter, ensuring consistent handling across the application.
- Improved the initialization of the group column for PostgreSQL compatibility.
2024-12-24 14:48:11 +08:00
Calcium-Ion
61495a460a Merge pull request #652 from Yan-Zero/main
fix: mutil func call in gemini
2024-12-23 20:50:31 +08:00
CalciumIon
cf3287a10a Merge remote-tracking branch 'origin/main' 2024-12-23 20:48:31 +08:00
CalciumIon
f3f1817aea feat: Add request start time context key and update middleware
- Introduced a new constant `ContextKeyRequestStartTime` to store the request start time in the context, enhancing request tracking.
- Updated the `Distribute` middleware to set the request start time in the context using the new constant.
- Modified the `GenRelayInfo` function to retrieve the request start time from the context, ensuring accurate timing information is used in relay operations.
2024-12-23 20:48:10 +08:00
Yan
a4795737fe fix: mutil func call in gemini 2024-12-23 01:26:14 +08:00
Calcium-Ion
eec8f523ce Merge pull request #651 from tenacioustommy/fix-gemini-json
fix-gemini-json-schema
2024-12-23 00:04:08 +08:00
CalciumIon
58fac129d6 feat: Enhance GeminiChatHandler to include RelayInfo
- Updated the GeminiChatHandler function to accept an additional parameter, RelayInfo, allowing for better context handling during chat operations.
- Modified the DoResponse method in the Adaptor to pass RelayInfo to GeminiChatHandler, ensuring consistent usage of upstream model information.
- Enhanced the GeminiChatStreamHandler to utilize the upstream model name from RelayInfo, improving response accuracy and data representation in Gemini requests.
2024-12-23 00:02:15 +08:00
CalciumIon
241c9389ef refactor: Remove unused context and logging in CovertGemini2OpenAI function
- Eliminated the unused `context` import and the logging of `geminiRequest` in the `CovertGemini2OpenAI` function, improving code cleanliness and reducing unnecessary overhead.
- This change enhances the maintainability of the code by removing redundant elements that do not contribute to functionality.
2024-12-22 23:54:11 +08:00
CalciumIon
1d0ef89ce9 feat: Add FunctionResponse type and enhance GeminiPart structure
- Introduced a new `FunctionResponse` type to encapsulate function call responses, improving the clarity of data handling.
- Updated the `GeminiPart` struct to include the new `FunctionResponse` field, allowing for better representation of function call results in Gemini requests.
- Modified the `CovertGemini2OpenAI` function to handle tool calls more effectively by setting the message role and appending function responses to the Gemini parts, enhancing the integration with OpenAI and Gemini systems.
2024-12-22 23:53:25 +08:00
tenacious
cce2990db6 fix-gemini-json 2024-12-22 23:48:09 +08:00
CalciumIon
a7e1d17c3e feat: Introduce settings package and refactor constants
- Added a new `setting` package to replace the `constant` package for configuration management, improving code organization and clarity.
- Moved various configuration variables such as `ServerAddress`, `PayAddress`, and `SensitiveWords` to the new `setting` package.
- Updated references throughout the codebase to use the new `setting` package, ensuring consistent access to configuration values.
- Introduced new files for managing chat settings and midjourney settings, enhancing modularity and maintainability of the code.
2024-12-22 17:24:29 +08:00
CalciumIon
c4e256e69b refactor: Update Message methods to use pointer receivers
- Refactored ParseToolCalls, SetToolCalls, IsStringContent, and ParseContent methods in the Message struct to use pointer receivers, improving efficiency and consistency in handling mutable state.
- Enhanced code readability and maintainability by ensuring all relevant methods operate on the pointer receiver, aligning with Go best practices.
2024-12-22 16:30:18 +08:00
CalciumIon
87a5e40daf refactor: Update SetToolCalls method to use pointer receiver
- Changed the SetToolCalls method to use a pointer receiver for the Message struct, allowing for modifications to the original instance.
- This change improves the method's efficiency and aligns with Go best practices for mutating struct methods.
2024-12-22 16:22:55 +08:00
CalciumIon
0c326556aa refactor: Update OpenAI request and message handling
- Changed the type of ToolCalls in the Message struct from `any` to `json.RawMessage` for better type safety and clarity.
- Introduced ParseToolCalls and SetToolCalls methods to handle ToolCalls more effectively, improving code readability and maintainability.
- Updated the ParseContent method to work with the new MediaContent type instead of MediaMessage, enhancing the structure of content parsing.
- Refactored Gemini relay functions to utilize the new ToolCalls handling methods, streamlining the integration with OpenAI and Gemini systems.
2024-12-22 16:20:30 +08:00
Calcium-Ion
794f6a6e34 Merge pull request #648 from palboss/main
解决  #534 用户管理-管理用户-查询报错 (SQLSTATE 42601)-postgresql
2024-12-22 14:37:34 +08:00
CalciumIon
656e809202 refactor: Simplify Gemini function parameter handling
- Removed redundant checks for non-empty properties in function parameters.
- Set function parameters to nil when no properties are needed, streamlining the logic for handling Gemini requests.
- Improved code clarity and maintainability by eliminating unnecessary complexity.
2024-12-22 14:35:21 +08:00
CalciumIon
53ab2aaee4 feat: Enhance Gemini function parameter handling
- Added logic to ensure that function parameters have non-empty properties.
- Implemented checks to add a default empty property if no parameters are needed.
- Updated the required field to match existing properties, improving the robustness of the Gemini function integration.
2024-12-22 14:29:14 +08:00
borland
a02bc3342f Update user.go 2024-12-22 00:03:00 +08:00
borland
f54d0cb3b0 Update user.go 2024-12-22 00:02:28 +08:00
CalciumIon
a5c48c2772 feat: Enhance LogsTable to render group information
- Added `renderGroup` function to improve the display of log data by rendering the 'group' information in the LogsTable component.
- Updated the rendering logic to utilize the new function, enhancing the UI's clarity and usability for grouped logs.
2024-12-21 20:28:26 +08:00
CalciumIon
cffaf0d636 feat: Add log information generation and enhance LogsTable component
- Introduced `log_info_generate.go` to implement functions for generating various log information, including text, WebSocket, and audio details.
- Enhanced `LogsTable` component to display the 'group' information from the log data, improving the visibility of grouped logs in the UI.
2024-12-21 20:24:22 +08:00
CalciumIon
865b98a454 Merge branch 'feat/o1'
# Conflicts:
#	dto/openai_request.go
2024-12-21 16:45:45 +08:00
Calcium-Ion
5bdbf3a673 Merge pull request #645 from MartialBE/gemini_res_format
支持gemini结构化输出
2024-12-21 16:40:51 +08:00
MartialBE
43a7b59b68 feat: support for Gemini structured output. 2024-12-21 16:01:17 +08:00
HynoR
eac3463401 Merge remote-tracking branch 'origin/feat/o1' into feat/o1 2024-12-20 23:14:20 +08:00
HynoR
1fa478af20 feat: 适配o1模型 2024-12-20 23:14:10 +08:00
TAKO
3b58b4989d Merge branch 'Calcium-Ion:main' into feat/o1 2024-12-20 22:12:26 +08:00
HynoR
0b1ba2eeb9 feat: 适配o1模型 2024-12-20 22:09:02 +08:00
HynoR
35277f2b4a feat: 适配o1模型 2024-12-20 22:07:53 +08:00
CalciumIon
03256dbdad refactor: Enhance error handling in Gemini request conversion
- Updated `CovertGemini2OpenAI` function to return an error alongside the GeminiChatRequest, improving error reporting for image processing.
- Modified `ConvertRequest` methods in both `adaptor.go` files to handle potential errors from the Gemini conversion, ensuring robust request handling.
- Improved clarity and maintainability of the code by explicitly managing error cases during request conversion.
2024-12-20 21:50:58 +08:00
69 changed files with 1383 additions and 655 deletions

6
.dockerignore Normal file
View File

@@ -0,0 +1,6 @@
.github
.git
*.md
.vscode
.gitignore
Makefile

2
.gitignore vendored
View File

@@ -8,3 +8,5 @@ build
logs
web/dist
.env
one-api
.DS_Store

View File

@@ -46,6 +46,8 @@ var defaultModelRatio = map[string]float64{
"gpt-4o-2024-08-06": 1.25, // $2.5 / 1M tokens
"gpt-4o-2024-11-20": 1.25, // $2.5 / 1M tokens
"gpt-4o-realtime-preview": 2.5,
"o1": 7.5,
"o1-2024-12-17": 7.5,
"o1-preview": 7.5,
"o1-preview-2024-09-12": 7.5,
"o1-mini": 1.5,
@@ -354,7 +356,7 @@ func GetCompletionRatio(name string) float64 {
}
return 2
}
if strings.HasPrefix(name, "o1-") {
if strings.HasPrefix(name, "o1") {
return 4
}
if name == "chatgpt-4o-latest" {

View File

@@ -1,46 +0,0 @@
package common
import (
"encoding/json"
)
var UserUsableGroups = map[string]string{
"default": "默认分组",
"vip": "vip分组",
}
func UserUsableGroups2JSONString() string {
jsonBytes, err := json.Marshal(UserUsableGroups)
if err != nil {
SysError("error marshalling user groups: " + err.Error())
}
return string(jsonBytes)
}
func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
UserUsableGroups = make(map[string]string)
return json.Unmarshal([]byte(jsonStr), &UserUsableGroups)
}
func GetUserUsableGroups(userGroup string) map[string]string {
if userGroup == "" {
// 如果userGroup为空返回UserUsableGroups
return UserUsableGroups
}
// 如果userGroup不在UserUsableGroups中返回UserUsableGroups + userGroup
if _, ok := UserUsableGroups[userGroup]; !ok {
appendUserUsableGroups := make(map[string]string)
for k, v := range UserUsableGroups {
appendUserUsableGroups[k] = v
}
appendUserUsableGroups[userGroup] = "用户分组"
return appendUserUsableGroups
}
// 如果userGroup在UserUsableGroups中返回UserUsableGroups
return UserUsableGroups
}
func GroupInUserUsableGroups(groupName string) bool {
_, ok := UserUsableGroups[groupName]
return ok
}

View File

@@ -4,7 +4,6 @@ import (
crand "crypto/rand"
"encoding/base64"
"fmt"
"github.com/google/uuid"
"html/template"
"log"
"math/big"
@@ -15,6 +14,8 @@ import (
"strconv"
"strings"
"time"
"github.com/google/uuid"
)
func OpenBrowser(url string) {

5
constant/context_key.go Normal file
View File

@@ -0,0 +1,5 @@
package constant
const (
ContextKeyRequestStartTime = "request_start_time"
)

View File

@@ -1,6 +1,9 @@
package constant
var (
FinishReasonStop = "stop"
FinishReasonToolCalls = "tool_calls"
FinishReasonStop = "stop"
FinishReasonToolCalls = "tool_calls"
FinishReasonLength = "length"
FinishReasonFunctionCall = "function_call"
FinishReasonContentFilter = "content_filter"
)

View File

@@ -1,11 +1,5 @@
package constant
var MjNotifyEnabled = false
var MjAccountFilterEnabled = false
var MjModeClearEnabled = false
var MjForwardUrlEnabled = true
var MjActionCheckSuccessEnabled = true
const (
MjErrorUnknown = 5
MjRequestError = 4

View File

@@ -141,7 +141,8 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
milliseconds := tok.Sub(tik).Milliseconds()
consumedTime := float64(milliseconds) / 1000.0
other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, other)
model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试",
quota, "模型测试", 0, quota, int(consumedTime), false, "default", other)
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}
@@ -151,8 +152,8 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
Model: "", // this will be set later
Stream: false,
}
if strings.HasPrefix(model, "o1-") {
testRequest.MaxCompletionTokens = 1
if strings.HasPrefix(model, "o1") {
testRequest.MaxCompletionTokens = 10
} else if strings.HasPrefix(model, "gemini-2.0-flash-thinking") {
testRequest.MaxTokens = 2
} else {

View File

@@ -97,6 +97,7 @@ func FetchUpstreamModels(c *gin.Context) {
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -105,34 +106,35 @@ func FetchUpstreamModels(c *gin.Context) {
})
return
}
if channel.Type != common.ChannelTypeOpenAI {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "仅支持 OpenAI 类型渠道",
})
return
//if channel.Type != common.ChannelTypeOpenAI {
// c.JSON(http.StatusOK, gin.H{
// "success": false,
// "message": "仅支持 OpenAI 类型渠道",
// })
// return
//}
baseURL := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
url := fmt.Sprintf("%s/v1/models", *channel.BaseURL)
url := fmt.Sprintf("%s/v1/models", baseURL)
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
result := OpenAIModelsResponse{}
err = json.Unmarshal(body, &result)
if err != nil {
var result OpenAIModelsResponse
if err = json.Unmarshal(body, &result); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
}
if !result.Success {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "上游返回错误",
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
})
return
}
var ids []string
@@ -417,7 +419,8 @@ func EditTagChannels(c *gin.Context) {
}
type ChannelBatch struct {
Ids []int `json:"ids"`
Ids []int `json:"ids"`
Tag *string `json:"tag"`
}
func DeleteChannelBatch(c *gin.Context) {
@@ -492,3 +495,105 @@ func UpdateChannel(c *gin.Context) {
})
return
}
func FetchModels(c *gin.Context) {
var req struct {
BaseURL string `json:"base_url"`
Key string `json:"key"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{
"success": false,
"message": "Invalid request",
})
return
}
baseURL := req.BaseURL
if baseURL == "" {
baseURL = "https://api.openai.com"
}
client := &http.Client{}
url := fmt.Sprintf("%s/v1/models", baseURL)
request, err := http.NewRequest("GET", url, nil)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": err.Error(),
})
return
}
request.Header.Set("Authorization", "Bearer "+req.Key)
response, err := client.Do(request)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": err.Error(),
})
return
}
//check status code
if response.StatusCode != http.StatusOK {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": "Failed to fetch models",
})
return
}
defer response.Body.Close()
var result struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{
"success": false,
"message": err.Error(),
})
return
}
var models []string
for _, model := range result.Data {
models = append(models, model.ID)
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": models,
})
}
func BatchSetChannelTag(c *gin.Context) {
channelBatch := ChannelBatch{}
err := c.ShouldBindJSON(&channelBatch)
if err != nil || len(channelBatch.Ids) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "参数错误",
})
return
}
err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": len(channelBatch.Ids),
})
return
}

View File

@@ -3,13 +3,13 @@ package controller
import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"one-api/setting"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
for groupName, _ := range common.GroupRatio {
for groupName, _ := range setting.GetGroupRatioCopy() {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{
@@ -24,9 +24,9 @@ func GetUserGroups(c *gin.Context) {
userGroup := ""
userId := c.GetInt("id")
userGroup, _ = model.CacheGetUserGroup(userId)
for groupName, _ := range common.GroupRatio {
for groupName, _ := range setting.GetGroupRatioCopy() {
// UserUsableGroups contains the groups that the user can use
userUsableGroups := common.GetUserUsableGroups(userGroup)
userUsableGroups := setting.GetUserUsableGroups(userGroup)
if _, ok := userUsableGroups[groupName]; ok {
usableGroups[groupName] = userUsableGroups[groupName]
}

View File

@@ -25,7 +25,8 @@ func GetAllLogs(c *gin.Context) {
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel)
group := c.Query("group")
logs, total, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, (p-1)*pageSize, pageSize, channel, group)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -63,7 +64,8 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize)
group := c.Query("group")
logs, total, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, (p-1)*pageSize, pageSize, group)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -146,7 +148,8 @@ func GetLogsStat(c *gin.Context) {
username := c.Query("username")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
group := c.Query("group")
stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
c.JSON(http.StatusOK, gin.H{
"success": true,
@@ -168,7 +171,8 @@ func GetLogsSelfStat(c *gin.Context) {
tokenName := c.Query("token_name")
modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel"))
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
group := c.Query("group")
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel, group)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
c.JSON(200, gin.H{
"success": true,

View File

@@ -10,10 +10,10 @@ import (
"log"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
"one-api/service"
"one-api/setting"
"strconv"
"time"
)
@@ -231,9 +231,9 @@ func GetAllMidjourney(c *gin.Context) {
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
if constant.MjForwardUrlEnabled {
if setting.MjForwardUrlEnabled {
for i, midjourney := range logs {
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney
}
}
@@ -263,9 +263,9 @@ func GetUserMidjourney(c *gin.Context) {
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
if constant.MjForwardUrlEnabled {
if setting.MjForwardUrlEnabled {
for i, midjourney := range logs {
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney
}
}

View File

@@ -5,8 +5,8 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/setting"
"strings"
"github.com/gin-gonic/gin"
@@ -47,9 +47,9 @@ func GetStatus(c *gin.Context) {
"footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled,
"server_address": constant.ServerAddress,
"price": constant.Price,
"min_topup": constant.MinTopUp,
"server_address": setting.ServerAddress,
"price": setting.Price,
"min_topup": setting.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
@@ -63,9 +63,9 @@ func GetStatus(c *gin.Context) {
"enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": constant.PayAddress != "" && constant.EpayId != "" && constant.EpayKey != "",
"mj_notify_enabled": constant.MjNotifyEnabled,
"chats": constant.Chats,
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": setting.MjNotifyEnabled,
"chats": setting.Chats,
},
})
return
@@ -207,7 +207,7 @@ func SendPasswordResetEmail(c *gin.Context) {
}
code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", constant.ServerAddress, email, code)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", common.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"one-api/setting"
"strings"
"github.com/gin-gonic/gin"
@@ -83,7 +84,7 @@ func UpdateOption(c *gin.Context) {
return
}
case "GroupRatio":
err = common.CheckGroupRatio(option.Value)
err = setting.CheckGroupRatio(option.Value)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,

View File

@@ -4,14 +4,38 @@ import (
"github.com/gin-gonic/gin"
"one-api/common"
"one-api/model"
"one-api/setting"
)
func GetPricing(c *gin.Context) {
pricing := model.GetPricing()
userId, exists := c.Get("id")
usableGroup := map[string]string{}
groupRatio := map[string]float64{}
for s, f := range setting.GetGroupRatioCopy() {
groupRatio[s] = f
}
var group string
if exists {
user, err := model.GetUserById(userId.(int), false)
if err == nil {
group = user.Group
}
}
usableGroup = setting.GetUserUsableGroups(group)
// check groupRatio contains usableGroup
for group := range setting.GetGroupRatioCopy() {
if _, ok := usableGroup[group]; !ok {
delete(groupRatio, group)
}
}
c.JSON(200, gin.H{
"success": true,
"data": pricing,
"group_ratio": common.GroupRatio,
"success": true,
"data": pricing,
"group_ratio": groupRatio,
"usable_group": usableGroup,
})
}

View File

@@ -17,6 +17,7 @@ import (
"one-api/relay/constant"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"strings"
)
@@ -83,7 +84,7 @@ func Playground(c *gin.Context) {
if group == "" {
group = userGroup
} else {
if !common.GroupInUserUsableGroups(group) && group != userGroup {
if !setting.GroupInUserUsableGroups(group) && group != userGroup {
openaiErr = service.OpenAIErrorWrapperLocal(errors.New("无权访问该分组"), "group_not_allowed", http.StatusForbidden)
return
}

View File

@@ -8,9 +8,9 @@ import (
"log"
"net/url"
"one-api/common"
"one-api/constant"
"one-api/model"
"one-api/service"
"one-api/setting"
"strconv"
"sync"
"time"
@@ -28,13 +28,13 @@ type AmountRequest struct {
}
func GetEpayClient() *epay.Client {
if constant.PayAddress == "" || constant.EpayId == "" || constant.EpayKey == "" {
if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" {
return nil
}
withUrl, err := epay.NewClient(&epay.Config{
PartnerID: constant.EpayId,
Key: constant.EpayKey,
}, constant.PayAddress)
PartnerID: setting.EpayId,
Key: setting.EpayKey,
}, setting.PayAddress)
if err != nil {
return nil
}
@@ -50,12 +50,12 @@ func getPayMoney(amount float64, group string) float64 {
if topupGroupRatio == 0 {
topupGroupRatio = 1
}
payMoney := amount * constant.Price * topupGroupRatio
payMoney := amount * setting.Price * topupGroupRatio
return payMoney
}
func getMinTopup() int {
minTopup := constant.MinTopUp
minTopup := setting.MinTopUp
if !common.DisplayInCurrencyEnabled {
minTopup = minTopup * int(common.QuotaPerUnit)
}
@@ -94,7 +94,7 @@ func RequestEpay(c *gin.Context) {
payType = "wxpay"
}
callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
returnUrl, _ := url.Parse(setting.ServerAddress + "/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"one-api/common"
"one-api/model"
"one-api/setting"
"strconv"
"strings"
"sync"
@@ -454,7 +455,15 @@ func GetUserModels(c *gin.Context) {
})
return
}
models := model.GetGroupModels(user.Group)
groups := setting.GetUserUsableGroups(user.Group)
var models []string
for group := range groups {
for _, g := range model.GetGroupModels(group) {
if !common.StringsContains(models, g) {
models = append(models, g)
}
}
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",

View File

@@ -3,39 +3,48 @@ package dto
import "encoding/json"
type ResponseFormat struct {
Type string `json:"type,omitempty"`
Type string `json:"type,omitempty"`
JsonSchema *FormatJsonSchema `json:"json_schema,omitempty"`
}
type FormatJsonSchema struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Schema any `json:"schema,omitempty"`
Strict any `json:"strict,omitempty"`
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCall `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
ReasoningEffort string `json:"reasoning_effort,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
EncodingFormat any `json:"encoding_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools []ToolCall `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"`
TopLogProbs int `json:"top_logprobs,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Modalities any `json:"modalities,omitempty"`
Audio any `json:"audio,omitempty"`
}
type OpenAITools struct {
@@ -80,11 +89,11 @@ type Message struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
}
type MediaMessage struct {
type MediaContent struct {
Type string `json:"type"`
Text string `json:"text"`
ImageUrl any `json:"image_url,omitempty"`
@@ -107,7 +116,23 @@ const (
ContentTypeInputAudio = "input_audio"
)
func (m Message) StringContent() string {
func (m *Message) ParseToolCalls() []ToolCall {
if m.ToolCalls == nil {
return nil
}
var toolCalls []ToolCall
if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
return toolCalls
}
return toolCalls
}
func (m *Message) SetToolCalls(toolCalls any) {
toolCallsJson, _ := json.Marshal(toolCalls)
m.ToolCalls = toolCallsJson
}
func (m *Message) StringContent() string {
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
return stringContent
@@ -120,7 +145,7 @@ func (m *Message) SetStringContent(content string) {
m.Content = jsonContent
}
func (m Message) IsStringContent() bool {
func (m *Message) IsStringContent() bool {
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
return true
@@ -128,11 +153,11 @@ func (m Message) IsStringContent() bool {
return false
}
func (m Message) ParseContent() []MediaMessage {
var contentList []MediaMessage
func (m *Message) ParseContent() []MediaContent {
var contentList []MediaContent
var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil {
contentList = append(contentList, MediaMessage{
contentList = append(contentList, MediaContent{
Type: ContentTypeText,
Text: stringContent,
})
@@ -148,7 +173,7 @@ func (m Message) ParseContent() []MediaMessage {
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, MediaMessage{
contentList = append(contentList, MediaContent{
Type: ContentTypeText,
Text: subStr,
})
@@ -161,7 +186,7 @@ func (m Message) ParseContent() []MediaMessage {
} else {
subObj["detail"] = "high"
}
contentList = append(contentList, MediaMessage{
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: subObj["url"].(string),
@@ -169,7 +194,7 @@ func (m Message) ParseContent() []MediaMessage {
},
})
} else if url, ok := contentMap["image_url"].(string); ok {
contentList = append(contentList, MediaMessage{
contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: url,
@@ -179,7 +204,7 @@ func (m Message) ParseContent() []MediaMessage {
}
case ContentTypeInputAudio:
if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
contentList = append(contentList, MediaMessage{
contentList = append(contentList, MediaContent{
Type: ContentTypeInputAudio,
InputAudio: MessageInputAudio{
Data: subObj["data"].(string),

View File

@@ -10,8 +10,10 @@ import (
"one-api/model"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
@@ -42,12 +44,12 @@ func Distribute() func(c *gin.Context) {
tokenGroup := c.GetString("token_group")
if tokenGroup != "" {
// check common.UserUsableGroups[userGroup]
if _, ok := common.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
if _, ok := setting.GetUserUsableGroups(userGroup)[tokenGroup]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("令牌分组 %s 已被禁用", tokenGroup))
return
}
// check group in common.GroupRatio
if _, ok := common.GroupRatio[tokenGroup]; !ok {
if !setting.ContainsGroupRatio(tokenGroup) {
abortWithOpenAiMessage(c, http.StatusForbidden, fmt.Sprintf("分组 %s 已被弃用", tokenGroup))
return
}
@@ -112,6 +114,7 @@ func Distribute() func(c *gin.Context) {
}
}
}
c.Set(constant.ContextKeyRequestStartTime, time.Now())
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
c.Next()
}

View File

@@ -3,10 +3,11 @@ package model
import (
"errors"
"fmt"
"github.com/samber/lo"
"gorm.io/gorm"
"one-api/common"
"strings"
"github.com/samber/lo"
"gorm.io/gorm"
)
type Ability struct {
@@ -173,18 +174,67 @@ func (channel *Channel) DeleteAbilities() error {
// UpdateAbilities updates abilities of this channel.
// Make sure the channel is completed before calling this function.
func (channel *Channel) UpdateAbilities() error {
// A quick and dirty way to update abilities
func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
isNewTx := false
// 如果没有传入事务,创建新的事务
if tx == nil {
tx = DB.Begin()
if tx.Error != nil {
return tx.Error
}
isNewTx = true
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
}
// First delete all abilities of this channel
err := channel.DeleteAbilities()
err := tx.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
if err != nil {
if isNewTx {
tx.Rollback()
}
return err
}
// Then add new abilities
err = channel.AddAbilities()
if err != nil {
return err
models_ := strings.Split(channel.Models, ",")
groups_ := strings.Split(channel.Group, ",")
abilities := make([]Ability, 0, len(models_))
for _, model := range models_ {
for _, group := range groups_ {
ability := Ability{
Group: group,
Model: model,
ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority,
Weight: uint(channel.GetWeight()),
Tag: channel.Tag,
}
abilities = append(abilities, ability)
}
}
if len(abilities) > 0 {
for _, chunk := range lo.Chunk(abilities, 50) {
err = tx.Create(&chunk).Error
if err != nil {
if isNewTx {
tx.Rollback()
}
return err
}
}
}
// 如果是新创建的事务,需要提交
if isNewTx {
return tx.Commit().Error
}
return nil
}
@@ -246,7 +296,7 @@ func FixAbility() (int, error) {
return 0, err
}
for _, channel := range channels {
err := channel.UpdateAbilities()
err := channel.UpdateAbilities(nil)
if err != nil {
common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
} else {

View File

@@ -257,7 +257,7 @@ func (channel *Channel) Update() error {
return err
}
DB.Model(channel).First(channel, "id = ?", channel.Id)
err = channel.UpdateAbilities()
err = channel.UpdateAbilities(nil)
return err
}
@@ -389,7 +389,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
channels, err := GetChannelsByTag(updatedTag, false)
if err == nil {
for _, channel := range channels {
err = channel.UpdateAbilities()
err = channel.UpdateAbilities(nil)
if err != nil {
common.SysError("failed to update abilities: " + err.Error())
}
@@ -509,3 +509,42 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) {
}
channel.Setting = string(settingBytes)
}
func GetChannelsByIds(ids []int) ([]*Channel, error) {
var channels []*Channel
err := DB.Where("id in (?)", ids).Find(&channels).Error
return channels, err
}
func BatchSetChannelTag(ids []int, tag *string) error {
// 开启事务
tx := DB.Begin()
if tx.Error != nil {
return tx.Error
}
// 更新标签
err := tx.Model(&Channel{}).Where("id in (?)", ids).Update("tag", tag).Error
if err != nil {
tx.Rollback()
return err
}
// update ability status
channels, err := GetChannelsByIds(ids)
if err != nil {
tx.Rollback()
return err
}
for _, channel := range channels {
err = channel.UpdateAbilities(tx)
if err != nil {
tx.Rollback()
return err
}
}
// 提交事务
return tx.Commit().Error
}

View File

@@ -12,6 +12,16 @@ import (
"gorm.io/gorm"
)
var groupCol string
func init() {
if common.UsingPostgreSQL {
groupCol = `"group"`
} else {
groupCol = "`group`"
}
}
type Log struct {
Id int `json:"id" gorm:"index:idx_created_at_id,priority:1"`
UserId int `json:"user_id" gorm:"index"`
@@ -28,6 +38,7 @@ type Log struct {
IsStream bool `json:"is_stream" gorm:"default:false"`
ChannelId int `json:"channel" gorm:"index"`
TokenId int `json:"token_id" gorm:"default:0;index"`
Group string `json:"group" gorm:"index"`
Other string `json:"other"`
}
@@ -70,7 +81,9 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int, isStream bool, other map[string]interface{}) {
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int,
modelName string, tokenName string, quota int, content string, tokenId int, userQuota int, useTimeSeconds int,
isStream bool, group string, other map[string]interface{}) {
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, 用户调用前余额=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, userQuota, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !common.LogConsumeEnabled {
return
@@ -92,6 +105,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
TokenId: tokenId,
UseTime: useTimeSeconds,
IsStream: isStream,
Group: group,
Other: otherStr,
}
err := LOG_DB.Create(log).Error
@@ -105,7 +119,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
}
}
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, total int64, err error) {
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int, group string) (logs []*Log, total int64, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = LOG_DB
@@ -130,6 +144,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
}
if group != "" {
tx = tx.Where(groupCol+" = ?", group)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
@@ -141,7 +158,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
return logs, total, err
}
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, total int64, err error) {
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string) (logs []*Log, total int64, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = LOG_DB.Where("user_id = ?", userId)
@@ -160,6 +177,9 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp)
}
if group != "" {
tx = tx.Where(groupCol+" = ?", group)
}
err = tx.Model(&Log{}).Count(&total).Error
if err != nil {
return nil, 0, err
@@ -193,7 +213,7 @@ type Stat struct {
Tpm int `json:"tpm"`
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat) {
tx := LOG_DB.Table("logs").Select("sum(quota) quota")
// 为rpm和tpm创建单独的查询
@@ -221,6 +241,10 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx = tx.Where("channel_id = ?", channel)
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
}
if group != "" {
tx = tx.Where(groupCol+" = ?", group)
rpmTpmQuery = rpmTpmQuery.Where(groupCol+" = ?", group)
}
tx = tx.Where("type = ?", LogTypeConsume)
rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)

View File

@@ -2,7 +2,7 @@ package model
import (
"one-api/common"
"one-api/constant"
"one-api/setting"
"strconv"
"strings"
"time"
@@ -61,16 +61,16 @@ func InitOptionMap() {
common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = ""
common.OptionMap["WorkerUrl"] = constant.WorkerUrl
common.OptionMap["WorkerValidKey"] = constant.WorkerValidKey
common.OptionMap["WorkerUrl"] = setting.WorkerUrl
common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
common.OptionMap["PayAddress"] = ""
common.OptionMap["CustomCallbackAddress"] = ""
common.OptionMap["EpayId"] = ""
common.OptionMap["EpayKey"] = ""
common.OptionMap["Price"] = strconv.FormatFloat(constant.Price, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(constant.MinTopUp)
common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = constant.Chats2JsonString()
common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = ""
@@ -87,8 +87,8 @@ func InitOptionMap() {
common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = common.UserUsableGroups2JSONString()
common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
common.OptionMap["UserUsableGroups"] = setting.UserUsableGroups2JSONString()
common.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
@@ -98,17 +98,17 @@ func InitOptionMap() {
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled)
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled)
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled)
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled)
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase()
@@ -209,23 +209,23 @@ func updateOptionMap(key string, value string) (err error) {
case "DefaultCollapseSidebar":
common.DefaultCollapseSidebar = boolValue
case "MjNotifyEnabled":
constant.MjNotifyEnabled = boolValue
setting.MjNotifyEnabled = boolValue
case "MjAccountFilterEnabled":
constant.MjAccountFilterEnabled = boolValue
setting.MjAccountFilterEnabled = boolValue
case "MjModeClearEnabled":
constant.MjModeClearEnabled = boolValue
setting.MjModeClearEnabled = boolValue
case "MjForwardUrlEnabled":
constant.MjForwardUrlEnabled = boolValue
setting.MjForwardUrlEnabled = boolValue
case "MjActionCheckSuccessEnabled":
constant.MjActionCheckSuccessEnabled = boolValue
setting.MjActionCheckSuccessEnabled = boolValue
case "CheckSensitiveEnabled":
constant.CheckSensitiveEnabled = boolValue
setting.CheckSensitiveEnabled = boolValue
case "CheckSensitiveOnPromptEnabled":
constant.CheckSensitiveOnPromptEnabled = boolValue
setting.CheckSensitiveOnPromptEnabled = boolValue
//case "CheckSensitiveOnCompletionEnabled":
// constant.CheckSensitiveOnCompletionEnabled = boolValue
case "StopOnSensitiveEnabled":
constant.StopOnSensitiveEnabled = boolValue
setting.StopOnSensitiveEnabled = boolValue
case "SMTPSSLEnabled":
common.SMTPSSLEnabled = boolValue
}
@@ -245,25 +245,25 @@ func updateOptionMap(key string, value string) (err error) {
case "SMTPToken":
common.SMTPToken = value
case "ServerAddress":
constant.ServerAddress = value
setting.ServerAddress = value
case "WorkerUrl":
constant.WorkerUrl = value
setting.WorkerUrl = value
case "WorkerValidKey":
constant.WorkerValidKey = value
setting.WorkerValidKey = value
case "PayAddress":
constant.PayAddress = value
setting.PayAddress = value
case "Chats":
err = constant.UpdateChatsByJsonString(value)
err = setting.UpdateChatsByJsonString(value)
case "CustomCallbackAddress":
constant.CustomCallbackAddress = value
setting.CustomCallbackAddress = value
case "EpayId":
constant.EpayId = value
setting.EpayId = value
case "EpayKey":
constant.EpayKey = value
setting.EpayKey = value
case "Price":
constant.Price, _ = strconv.ParseFloat(value, 64)
setting.Price, _ = strconv.ParseFloat(value, 64)
case "MinTopUp":
constant.MinTopUp, _ = strconv.Atoi(value)
setting.MinTopUp, _ = strconv.Atoi(value)
case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId":
@@ -313,9 +313,9 @@ func updateOptionMap(key string, value string) (err error) {
case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value)
err = setting.UpdateGroupRatioByJSONString(value)
case "UserUsableGroups":
err = common.UpdateUserUsableGroupsByJSONString(value)
err = setting.UpdateUserUsableGroupsByJSONString(value)
case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value)
case "ModelPrice":
@@ -331,9 +331,9 @@ func updateOptionMap(key string, value string) (err error) {
case "QuotaPerUnit":
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "SensitiveWords":
constant.SensitiveWordsFromString(value)
setting.SensitiveWordsFromString(value)
case "StreamCacheQueueLength":
constant.StreamCacheQueueLength, _ = strconv.Atoi(value)
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
}
return err
}

View File

@@ -5,8 +5,8 @@ import (
"fmt"
"gorm.io/gorm"
"one-api/common"
"one-api/constant"
relaycommon "one-api/relay/common"
"one-api/setting"
"strconv"
"strings"
)
@@ -325,7 +325,7 @@ func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quot
prompt = "您的额度已用尽"
}
if email != "" {
topUpLink := fmt.Sprintf("%s/topup", constant.ServerAddress)
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
err = common.SendEmail(prompt, email,
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil {

View File

@@ -89,26 +89,31 @@ func SearchUsers(keyword string, group string) ([]*User, error) {
var users []*User
var err error
groupCol := "`group`"
if common.UsingPostgreSQL {
groupCol = `"group"`
}
// 尝试将关键字转换为整数ID
keywordInt, err := strconv.Atoi(keyword)
if err == nil {
// 如果转换成功按照ID和可选的组别搜索用户
query := DB.Unscoped().Omit("password").Where("`id` = ?", keywordInt)
query := DB.Unscoped().Omit("password").Where("id = ?", keywordInt)
if group != "" {
query = query.Where("`group` = ?", group) // 使用反引号包围group
query = query.Where(groupCol+" = ?", group) // 使用反引号包围group
}
err = query.Find(&users).Error
if err != nil || len(users) > 0 {
return users, err
}
}
err = nil
query := DB.Unscoped().Omit("password")
likeCondition := "`username` LIKE ? OR `email` LIKE ? OR `display_name` LIKE ?"
likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
if group != "" {
query = query.Where("("+likeCondition+") AND `group` = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
query = query.Where("("+likeCondition+") AND "+groupCol+" = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
} else {
query = query.Where(likeCondition, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
}

View File

@@ -240,14 +240,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
}
if message.ToolCalls != nil {
for _, tc := range message.ToolCalls.([]interface{}) {
toolCallJSON, _ := json.Marshal(tc)
var toolCall dto.ToolCall
err := json.Unmarshal(toolCallJSON, &toolCall)
if err != nil {
common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
continue
}
for _, toolCall := range message.ParseToolCalls() {
inputObj := make(map[string]any)
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
@@ -393,7 +386,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
}
choice.SetStringContent(responseText)
if len(tools) > 0 {
choice.Message.ToolCalls = tools
choice.Message.SetToolCalls(tools)
}
fullTextResponse.Model = claudeResponse.Model
choices = append(choices, choice)

View File

@@ -57,7 +57,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil {
return nil, errors.New("request is nil")
}
return CovertGemini2OpenAI(*request), nil
ai, err := CovertGemini2OpenAI(*request)
if err != nil {
return nil, err
}
return ai, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -72,7 +76,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
err, usage = GeminiChatStreamHandler(c, resp, info)
} else {
err, usage = GeminiChatHandler(c, resp)
err, usage = GeminiChatHandler(c, resp, info)
}
return
}

View File

@@ -4,7 +4,7 @@ type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
Tools []GeminiChatTools `json:"tools,omitempty"`
Tools []GeminiChatTool `json:"tools,omitempty"`
SystemInstructions *GeminiChatContent `json:"system_instruction,omitempty"`
}
@@ -18,10 +18,39 @@ type FunctionCall struct {
Arguments any `json:"args"`
}
type GeminiFunctionResponseContent struct {
Name string `json:"name"`
Content any `json:"content"`
}
type FunctionResponse struct {
Name string `json:"name"`
Response GeminiFunctionResponseContent `json:"response"`
}
type GeminiPartExecutableCode struct {
Language string `json:"language,omitempty"`
Code string `json:"code,omitempty"`
}
type GeminiPartCodeExecutionResult struct {
Outcome string `json:"outcome,omitempty"`
Output string `json:"output,omitempty"`
}
type GeminiFileData struct {
MimeType string `json:"mimeType,omitempty"`
FileUri string `json:"fileUri,omitempty"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
FileData *GeminiFileData `json:"fileData,omitempty"`
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
}
type GeminiChatContent struct {
@@ -34,23 +63,28 @@ type GeminiChatSafetySettings struct {
Threshold string `json:"threshold"`
}
type GeminiChatTools struct {
GoogleSearch any `json:"googleSearch,omitempty"`
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
type GeminiChatTool struct {
GoogleSearch any `json:"googleSearch,omitempty"`
GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
CodeExecution any `json:"codeExecution,omitempty"`
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`
FinishReason *string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}

View File

@@ -17,7 +17,8 @@ import (
)
// Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatRequest {
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatRequest, error) {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
SafetySettings: []GeminiChatSafetySettings{
@@ -46,63 +47,139 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
Seed: int64(textRequest.Seed),
},
}
// openaiContent.FuncToToolCalls()
if textRequest.Tools != nil {
functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
googleSearch := false
codeExecution := false
for _, tool := range textRequest.Tools {
if tool.Function.Name == "googleSearch" {
googleSearch = true
continue
}
if tool.Function.Name == "codeExecution" {
codeExecution = true
continue
}
if tool.Function.Parameters != nil {
params, ok := tool.Function.Parameters.(map[string]interface{})
if ok {
if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
if len(props) == 0 {
tool.Function.Parameters = nil
}
}
}
}
functions = append(functions, tool.Function)
}
if len(functions) > 0 {
geminiRequest.Tools = []GeminiChatTools{
{
FunctionDeclarations: functions,
},
}
if codeExecution {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
CodeExecution: make(map[string]string),
})
}
if googleSearch {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTools{
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
GoogleSearch: make(map[string]string),
})
}
if len(functions) > 0 {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
FunctionDeclarations: functions,
})
}
// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
// json_data, _ := json.Marshal(geminiRequest.Tools)
// common.SysLog("tools_json: " + string(json_data))
} else if textRequest.Functions != nil {
geminiRequest.Tools = []GeminiChatTools{
geminiRequest.Tools = []GeminiChatTool{
{
FunctionDeclarations: textRequest.Functions,
},
}
}
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
}
}
tool_call_ids := make(map[string]string)
var system_content []string
//shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
if message.Role == "system" {
geminiRequest.SystemInstructions = &GeminiChatContent{
Parts: []GeminiPart{
{
Text: message.StringContent(),
},
system_content = append(system_content, message.StringContent())
continue
} else if message.Role == "tool" || message.Role == "function" {
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
Role: "user",
})
}
var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
name := ""
if message.Name != nil {
name = *message.Name
} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
name = val
}
content := common.StrToMap(message.StringContent())
functionResp := &FunctionResponse{
Name: name,
Response: GeminiFunctionResponseContent{
Name: name,
Content: content,
},
}
if content == nil {
functionResp.Response.Content = message.StringContent()
}
*parts = append(*parts, GeminiPart{
FunctionResponse: functionResp,
})
continue
}
var parts []GeminiPart
content := GeminiChatContent{
Role: message.Role,
//Parts: []GeminiPart{
// {
// Text: message.StringContent(),
// },
//},
}
// isToolCall := false
if message.ToolCalls != nil {
// message.Role = "model"
// isToolCall = true
for _, call := range message.ParseToolCalls() {
args := map[string]interface{}{}
if call.Function.Arguments != "" {
if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
}
}
toolCall := GeminiPart{
FunctionCall: &FunctionCall{
FunctionName: call.Function.Name,
Arguments: args,
},
}
parts = append(parts, toolCall)
tool_call_ids[call.ID] = call.Function.Name
}
}
openaiContent := message.ParseContent()
var parts []GeminiPart
imageNum := 0
for _, part := range openaiContent {
if part.Type == dto.ContentTypeText {
if part.Text == "" {
continue
}
parts = append(parts, GeminiPart{
Text: part.Text,
})
@@ -110,7 +187,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
imageNum += 1
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
continue
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
}
// 判断是否是url
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
@@ -125,7 +202,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
} else {
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
if err != nil {
continue
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
@@ -136,58 +213,86 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
}
}
}
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"
}
// Converting system prompt to prompt from user for the same reason
//if content.Role == "system" {
// content.Role = "user"
// shouldAddDummyModelMessage = true
//}
geminiRequest.Contents = append(geminiRequest.Contents, content)
//
//// If a system message is the last message, we need to add a dummy model message to make gemini happy
//if shouldAddDummyModelMessage {
// geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
// Role: "model",
// Parts: []GeminiPart{
// {
// Text: "Okay",
// },
// },
// })
// shouldAddDummyModelMessage = false
//}
}
return &geminiRequest
if len(system_content) > 0 {
geminiRequest.SystemInstructions = &GeminiChatContent{
Parts: []GeminiPart{
{
Text: strings.Join(system_content, "\n"),
},
},
}
}
return &geminiRequest, nil
}
func (g *GeminiChatResponse) GetResponseText() string {
if g == nil {
return ""
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
if depth >= 5 {
return schema
}
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
return g.Candidates[0].Content.Parts[0].Text
v, ok := schema.(map[string]interface{})
if !ok || len(v) == 0 {
return schema
}
return ""
// 删除所有的title字段
delete(v, "title")
// 如果type不为object和array则直接返回
if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
return schema
}
switch v["type"] {
case "object":
delete(v, "additionalProperties")
// 处理 properties
if properties, ok := v["properties"].(map[string]interface{}); ok {
for key, value := range properties {
properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
}
}
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
if nested, ok := v[field].([]interface{}); ok {
for i, item := range nested {
nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
}
}
}
case "array":
if items, ok := v["items"].(map[string]interface{}); ok {
v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
}
}
return v
}
func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
var toolCalls []dto.ToolCall
// func (g *GeminiChatResponse) GetResponseText() string {
// if g == nil {
// return ""
// }
// if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
// return g.Candidates[0].Content.Parts[0].Text
// }
// return ""
// }
item := candidate.Content.Parts[0]
if item.FunctionCall == nil {
return toolCalls
}
func getToolCall(item *GeminiPart) *dto.ToolCall {
argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
if err != nil {
//common.SysError("getToolCalls failed: " + err.Error())
return toolCalls
//common.SysError("getToolCall failed: " + err.Error())
return nil
}
toolCall := dto.ToolCall{
return &dto.ToolCall{
ID: fmt.Sprintf("call_%s", common.GetUUID()),
Type: "function",
Function: dto.FunctionCall{
@@ -195,10 +300,32 @@ func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
Name: item.FunctionCall.FunctionName,
},
}
toolCalls = append(toolCalls, toolCall)
return toolCalls
}
// func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
// var toolCalls []dto.ToolCall
// item := candidate.Content.Parts[index]
// if item.FunctionCall == nil {
// return toolCalls
// }
// argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
// if err != nil {
// //common.SysError("getToolCalls failed: " + err.Error())
// return toolCalls
// }
// toolCall := dto.ToolCall{
// ID: fmt.Sprintf("call_%s", common.GetUUID()),
// Type: "function",
// Function: dto.FunctionCall{
// Arguments: string(argsBytes),
// Name: item.FunctionCall.FunctionName,
// },
// }
// toolCalls = append(toolCalls, toolCall)
// return toolCalls
// }
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -207,9 +334,10 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
content, _ := json.Marshal("")
for i, candidate := range response.Candidates {
is_tool_call := false
for _, candidate := range response.Candidates {
choice := dto.OpenAITextResponseChoice{
Index: i,
Index: int(candidate.Index),
Message: dto.Message{
Role: "assistant",
Content: content,
@@ -217,48 +345,116 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
FinishReason: constant.FinishReasonStop,
}
if len(candidate.Content.Parts) > 0 {
if candidate.Content.Parts[0].FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls
choice.Message.ToolCalls = getToolCalls(&candidate)
} else {
var texts []string
for _, part := range candidate.Content.Parts {
texts = append(texts, part.Text)
var texts []string
var tool_calls []dto.ToolCall
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls
if call := getToolCall(&part); call != nil {
tool_calls = append(tool_calls, *call)
}
} else {
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
} else if part.CodeExecutionResult != nil {
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
} else {
// 过滤掉空行
if part.Text != "\n" {
texts = append(texts, part.Text)
}
}
}
choice.Message.SetStringContent(strings.Join(texts, "\n"))
}
if len(tool_calls) > 0 {
choice.Message.SetToolCalls(tool_calls)
is_tool_call = true
}
// 过滤掉空行
choice.Message.SetStringContent(strings.Join(texts, "\n"))
}
if candidate.FinishReason != nil {
switch *candidate.FinishReason {
case "STOP":
choice.FinishReason = constant.FinishReasonStop
case "MAX_TOKENS":
choice.FinishReason = constant.FinishReasonLength
default:
choice.FinishReason = constant.FinishReasonContentFilter
}
}
if is_tool_call {
choice.FinishReason = constant.FinishReasonToolCalls
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
//choice.Delta.SetContentString(geminiResponse.GetResponseText())
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
respFirstParts := geminiResponse.Candidates[0].Content.Parts
if respFirstParts[0].FunctionCall != nil {
// function response
choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
} else {
// text response
var texts []string
for _, part := range respFirstParts {
texts = append(texts, part.Text)
}
choice.Delta.SetContentString(strings.Join(texts, "\n"))
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool) {
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
is_stop := false
for _, candidate := range geminiResponse.Candidates {
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
is_stop = true
candidate.FinishReason = nil
}
choice := dto.ChatCompletionsStreamResponseChoice{
Index: int(candidate.Index),
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
Role: "assistant",
},
}
var texts []string
isTools := false
if candidate.FinishReason != nil {
// p := GeminiConvertFinishReason(*candidate.FinishReason)
switch *candidate.FinishReason {
case "STOP":
choice.FinishReason = &constant.FinishReasonStop
case "MAX_TOKENS":
choice.FinishReason = &constant.FinishReasonLength
default:
choice.FinishReason = &constant.FinishReasonContentFilter
}
}
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
isTools = true
if call := getToolCall(&part); call != nil {
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
}
} else {
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
} else if part.CodeExecutionResult != nil {
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
} else {
if part.Text != "\n" {
texts = append(texts, part.Text)
}
}
}
}
choice.Delta.SetContentString(strings.Join(texts, "\n"))
if isTools {
choice.FinishReason = &constant.FinishReasonToolCalls
}
choices = append(choices, choice)
}
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
return &response
response.Choices = choices
return &response, is_stop
}
func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseText := ""
// responseText := ""
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
@@ -282,13 +478,11 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
continue
}
response := streamResponseGeminiChat2OpenAI(&geminiResponse)
if response == nil {
continue
}
response, is_stop := streamResponseGeminiChat2OpenAI(&geminiResponse)
response.Id = id
response.Created = createAt
responseText += response.Choices[0].Delta.GetContentString()
response.Model = info.UpstreamModelName
// responseText += response.Choices[0].Delta.GetContentString()
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
@@ -297,12 +491,17 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
if err != nil {
common.LogError(c, err.Error())
}
if is_stop {
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
service.ObjectData(c, response)
}
}
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
service.ObjectData(c, response)
var response *dto.ChatCompletionsStreamResponse
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
usage.CompletionTokenDetails.TextTokens = usage.CompletionTokens
if info.ShouldIncludeUsage {
response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
@@ -316,7 +515,7 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
return nil, usage
}
func GeminiChatHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -342,6 +541,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWit
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
fullTextResponse.Model = info.UpstreamModelName
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,

View File

@@ -106,15 +106,22 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
if request == nil {
return nil, errors.New("request is nil")
}
if info.ChannelType != common.ChannelTypeOpenAI {
if info.ChannelType != common.ChannelTypeOpenAI && info.ChannelType != common.ChannelTypeAzure {
request.StreamOptions = nil
}
if strings.HasPrefix(request.Model, "o1-") {
if strings.HasPrefix(request.Model, "o1") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
}
if request.Model == "o1" || request.Model == "o1-2024-12-17" {
//修改第一个Message的内容将system改为developer
if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
request.Messages[0].Role = "developer"
}
}
return request, nil
}

View File

@@ -13,6 +13,7 @@ var ModelList = []string{
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"o1-preview", "o1-preview-2024-09-12",
"o1-mini", "o1-mini-2024-09-12",
"o1", "o1-2024-12-17",
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",

View File

@@ -135,7 +135,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
c.Set("request_model", request.Model)
return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini {
geminiRequest := gemini.CovertGemini2OpenAI(*request)
geminiRequest, err := gemini.CovertGemini2OpenAI(*request)
if err != nil {
return nil, err
}
c.Set("request_model", request.Model)
return geminiRequest, nil
} else if a.RequestMode == RequestModeLlama {
@@ -167,7 +170,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
case RequestModeClaude:
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
case RequestModeGemini:
err, usage = gemini.GeminiChatHandler(c, resp)
err, usage = gemini.GeminiChatHandler(c, resp, info)
case RequestModeLlama:
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.OriginModelName)
}

View File

@@ -2,8 +2,9 @@ package common
import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/constant"
relayconstant "one-api/relay/constant"
"strings"
"time"
@@ -66,13 +67,13 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
userId := c.GetInt("id")
group := c.GetString("group")
tokenUnlimited := c.GetBool("token_unlimited_quota")
startTime := time.Now()
startTime := c.GetTime(constant.ContextKeyRequestStartTime)
// firstResponseTime = time.Now() - 1 second
apiType, _ := constant.ChannelType2APIType(channelType)
apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &RelayInfo{
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,
@@ -108,7 +109,7 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
}
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
info.ChannelType == common.ChannelCloudflare {
info.ChannelType == common.ChannelCloudflare || info.ChannelType == common.ChannelTypeAzure {
info.SupportStreamOptions = true
}
return info
@@ -158,10 +159,10 @@ func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
group := c.GetString("group")
startTime := time.Now()
apiType, _ := constant.ChannelType2APIType(channelType)
apiType, _ := relayconstant.ChannelType2APIType(channelType)
info := &TaskRelayInfo{
RelayMode: constant.Path2RelayMode(c.Request.URL.Path),
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
BaseUrl: c.GetString("base_url"),
RequestURLPath: c.Request.URL.String(),
ChannelType: channelType,

View File

@@ -7,12 +7,12 @@ import (
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
)
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
@@ -26,7 +26,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
if audioRequest.Model == "" {
return nil, errors.New("model is required")
}
if constant.ShouldCheckPromptSensitive() {
if setting.ShouldCheckPromptSensitive() {
err := service.CheckSensitiveInput(audioRequest.Input)
if err != nil {
return nil, err
@@ -74,7 +74,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)

View File

@@ -9,11 +9,11 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting"
"strings"
)
@@ -59,7 +59,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
//}
if constant.ShouldCheckPromptSensitive() {
if setting.ShouldCheckPromptSensitive() {
err := service.CheckSensitiveInput(imageRequest.Prompt)
if err != nil {
return nil, err
@@ -99,7 +99,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
modelPrice = 0.0025 * modelRatio
}
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
sizeRatio := 1.0

View File

@@ -15,6 +15,7 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"strconv"
"strings"
"time"
@@ -111,8 +112,8 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.StartTime = originTask.StartTime
midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled {
midjourneyTask.ImageUrl = constant.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled {
midjourneyTask.ImageUrl = setting.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
}
@@ -167,7 +168,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
modelPrice = defaultPrice
}
}
groupRatio := common.GetGroupRatio(group)
groupRatio := setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
@@ -207,7 +208,8 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false, other)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
@@ -421,7 +423,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if originTask == nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
if constant.MjActionCheckSuccessEnabled {
if setting.MjActionCheckSuccessEnabled {
if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
}
@@ -472,7 +474,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
modelPrice = defaultPrice
}
}
groupRatio := common.GetGroupRatio(group)
groupRatio := setting.GetGroupRatio(group)
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
@@ -512,7 +514,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent, tokenId, userQuota, 0, false, other)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName,
quota, logContent, tokenId, userQuota, 0, false, group, other)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

View File

@@ -15,6 +15,7 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
"strings"
"time"
@@ -93,24 +94,31 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
}
relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64
var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest)
if constant.ShouldCheckPromptSensitive() {
if setting.ShouldCheckPromptSensitive() {
err = checkRequestSensitive(textRequest, relayInfo)
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
}
}
promptTokens, err := getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
// 获取 promptTokens如果上下文中已经存在则直接使用
var promptTokens int
if value, exists := c.Get("prompt_tokens"); exists {
promptTokens = value.(int)
} else {
promptTokens, err = getPromptTokens(textRequest, relayInfo)
// count messages token error 计算promptTokens错误
if err != nil {
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
}
c.Set("prompt_tokens", promptTokens)
}
if !getModelPriceSuccess {
@@ -384,7 +392,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
}
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
//if quota != 0 {
//

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting"
)
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
@@ -57,7 +58,7 @@ func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWith
relayInfo.UpstreamModelName = rerankRequest.Model
modelPrice, success := common.GetModelPrice(rerankRequest.Model, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64

View File

@@ -16,6 +16,7 @@ import (
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/setting"
)
/*
@@ -48,7 +49,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
}
// 预扣
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
ratio := modelPrice * groupRatio
userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
if err != nil {
@@ -126,7 +127,8 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
other := make(map[string]interface{})
other["model_price"] = modelPrice
other["group_ratio"] = groupRatio
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, other)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0,
modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, relayInfo.Group, other)
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}

View File

@@ -10,6 +10,7 @@ import (
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/setting"
)
//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) {
@@ -57,7 +58,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWi
}
//relayInfo.UpstreamModelName = textRequest.Model
modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
var preConsumedQuota int
var ratio float64

View File

@@ -28,10 +28,10 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), controller.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), controller.EmailBind)
apiRouter.GET("/oauth/telegram/login", middleware.CriticalRateLimit(), controller.TelegramLogin)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.TelegramBind)
apiRouter.GET("/oauth/telegram/bind", middleware.CriticalRateLimit(), controller.TelegramBind)
userRoute := apiRouter.Group("/user")
{
@@ -98,7 +98,8 @@ func SetApiRouter(router *gin.Engine) {
channelRoute.POST("/batch", controller.DeleteChannelBatch)
channelRoute.POST("/fix", controller.FixChannelsAbilities)
channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
channelRoute.POST("/fetch_models", controller.FetchModels)
channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
}
tokenRoute := apiRouter.Group("/token")
tokenRoute.Use(middleware.UserAuth())

View File

@@ -1,12 +1,12 @@
package service
import (
"one-api/constant"
"one-api/setting"
)
func GetCallbackAddress() string {
if constant.CustomCallbackAddress == "" {
return constant.ServerAddress
if setting.CustomCallbackAddress == "" {
return setting.ServerAddress
}
return constant.CustomCallbackAddress
return setting.CustomCallbackAddress
}

View File

@@ -11,6 +11,7 @@ import (
"one-api/constant"
"one-api/dto"
relayconstant "one-api/relay/constant"
"one-api/setting"
"strconv"
"strings"
"time"
@@ -167,16 +168,16 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
}
if !constant.MjAccountFilterEnabled {
if !setting.MjAccountFilterEnabled {
delete(mapResult, "accountFilter")
}
if !constant.MjNotifyEnabled {
if !setting.MjNotifyEnabled {
delete(mapResult, "notifyHook")
}
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
// make new request with mapResult
}
if constant.MjModeClearEnabled {
if setting.MjModeClearEnabled {
if prompt, ok := mapResult["prompt"].(string); ok {
prompt = strings.Replace(prompt, "--fast", "", -1)
prompt = strings.Replace(prompt, "--relax", "", -1)

View File

@@ -9,6 +9,7 @@ import (
"one-api/dto"
"one-api/model"
relaycommon "one-api/relay/common"
"one-api/setting"
"strings"
"time"
)
@@ -36,7 +37,7 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
completionRatio := common.GetCompletionRatio(modelName)
audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName)
audioCompletionRatio := common.GetAudioCompletionRatio(modelName)
groupRatio := common.GetGroupRatio(relayInfo.Group)
groupRatio := setting.GetGroupRatio(relayInfo.Group)
modelRatio := common.GetModelRatio(modelName)
ratio := groupRatio * modelRatio
@@ -139,7 +140,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
@@ -208,5 +209,5 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, audioRatio, audioCompletionRatio, modelPrice)
model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.PromptTokens, usage.CompletionTokens, logModel,
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other)
tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, relayInfo.Group, other)
}

View File

@@ -3,8 +3,8 @@ package service
import (
"errors"
"fmt"
"one-api/constant"
"one-api/dto"
"one-api/setting"
"strings"
)
@@ -56,7 +56,7 @@ func CheckSensitiveInput(input any) error {
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
func SensitiveWordContains(text string) (bool, []string) {
if len(constant.SensitiveWords) == 0 {
if len(setting.SensitiveWords) == 0 {
return false, nil
}
checkText := strings.ToLower(text)
@@ -75,7 +75,7 @@ func SensitiveWordContains(text string) (bool, []string) {
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) {
if len(constant.SensitiveWords) == 0 {
if len(setting.SensitiveWords) == 0 {
return false, nil, text
}
checkText := strings.ToLower(text)

View File

@@ -4,7 +4,7 @@ import (
"bytes"
"fmt"
goahocorasick "github.com/anknown/ahocorasick"
"one-api/constant"
"one-api/setting"
"strings"
)
@@ -70,7 +70,7 @@ func InitAc() *goahocorasick.Machine {
func readRunes() [][]rune {
var dict [][]rune
for _, word := range constant.SensitiveWords {
for _, word := range setting.SensitiveWords {
word = strings.ToLower(word)
l := bytes.TrimSpace([]byte(word))
dict = append(dict, bytes.Runes(l))

View File

@@ -19,42 +19,40 @@ import (
// tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken
var cl200kTokenEncoder *tiktoken.Tiktoken
var o200kTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() {
common.SysLog("initializing token encoders")
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
cl100TokenEncoder, err := tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
}
defaultTokenEncoder = gpt35TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
defaultTokenEncoder = cl100TokenEncoder
o200kTokenEncoder, err = tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
}
for model, _ := range common.GetDefaultModelRatioMap() {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
tokenEncoderMap[model] = cl100TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
if strings.HasPrefix(model, "gpt-4o") {
tokenEncoderMap[model] = cl200kTokenEncoder
tokenEncoderMap[model] = o200kTokenEncoder
} else {
tokenEncoderMap[model] = gpt4TokenEncoder
tokenEncoderMap[model] = defaultTokenEncoder
}
} else if strings.HasPrefix(model, "o1") {
tokenEncoderMap[model] = o200kTokenEncoder
} else {
tokenEncoderMap[model] = nil
tokenEncoderMap[model] = defaultTokenEncoder
}
}
common.SysLog("token encoders initialized")
}
func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") {
return cl200kTokenEncoder
if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") || strings.HasPrefix(model, "o1") {
return o200kTokenEncoder
}
return defaultTokenEncoder
}
@@ -92,11 +90,11 @@ func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (in
}
// TODO: 非流模式下不计算图片token数量
if !constant.GetMediaTokenNotStream && !stream {
return 1000, nil
return 256, nil
}
// 是否统计图片token
if !constant.GetMediaToken {
return 1000, nil
return 256, nil
}
// 同步One API的图片计费逻辑
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {

View File

@@ -5,20 +5,20 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/setting"
"strings"
)
func DoImageRequest(originUrl string) (resp *http.Response, err error) {
if constant.EnableWorker() {
if setting.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
workerUrl := constant.WorkerUrl
workerUrl := setting.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/"
}
// post request to worker
data := []byte(`{"url":"` + originUrl + `","key":"` + constant.WorkerValidKey + `"}`)
return http.Post(constant.WorkerUrl, "application/json", bytes.NewBuffer(data))
data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
} else {
common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
return http.Get(originUrl)

View File

@@ -1,4 +1,4 @@
package constant
package setting
import (
"encoding/json"

View File

@@ -1,33 +1,47 @@
package common
package setting
import (
"encoding/json"
"errors"
"one-api/common"
)
var GroupRatio = map[string]float64{
var groupRatio = map[string]float64{
"default": 1,
"vip": 1,
"svip": 1,
}
func GetGroupRatioCopy() map[string]float64 {
groupRatioCopy := make(map[string]float64)
for k, v := range groupRatio {
groupRatioCopy[k] = v
}
return groupRatioCopy
}
func ContainsGroupRatio(name string) bool {
_, ok := groupRatio[name]
return ok
}
func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio)
jsonBytes, err := json.Marshal(groupRatio)
if err != nil {
SysError("error marshalling model ratio: " + err.Error())
common.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateGroupRatioByJSONString(jsonStr string) error {
GroupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &GroupRatio)
groupRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &groupRatio)
}
func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name]
ratio, ok := groupRatio[name]
if !ok {
SysError("group ratio not found: " + name)
common.SysError("group ratio not found: " + name)
return 1
}
return ratio

7
setting/midjourney.go Normal file
View File

@@ -0,0 +1,7 @@
package setting
var MjNotifyEnabled = false
var MjAccountFilterEnabled = false
var MjModeClearEnabled = false
var MjForwardUrlEnabled = true
var MjActionCheckSuccessEnabled = true

View File

@@ -1,4 +1,4 @@
package constant
package setting
var PayAddress = ""
var CustomCallbackAddress = ""

View File

@@ -1,4 +1,4 @@
package constant
package setting
import "strings"

View File

@@ -1,4 +1,4 @@
package constant
package setting
var ServerAddress = "http://localhost:3000"
var WorkerUrl = ""

View File

@@ -0,0 +1,52 @@
package setting
import (
"encoding/json"
"one-api/common"
)
var userUsableGroups = map[string]string{
"default": "默认分组",
"vip": "vip分组",
}
func GetUserUsableGroupsCopy() map[string]string {
copyUserUsableGroups := make(map[string]string)
for k, v := range userUsableGroups {
copyUserUsableGroups[k] = v
}
return copyUserUsableGroups
}
func UserUsableGroups2JSONString() string {
jsonBytes, err := json.Marshal(userUsableGroups)
if err != nil {
common.SysError("error marshalling user groups: " + err.Error())
}
return string(jsonBytes)
}
func UpdateUserUsableGroupsByJSONString(jsonStr string) error {
userUsableGroups = make(map[string]string)
return json.Unmarshal([]byte(jsonStr), &userUsableGroups)
}
func GetUserUsableGroups(userGroup string) map[string]string {
groupsCopy := GetUserUsableGroupsCopy()
if userGroup == "" {
if _, ok := groupsCopy["default"]; !ok {
groupsCopy["default"] = "default"
}
}
// 如果userGroup不在UserUsableGroups中返回UserUsableGroups + userGroup
if _, ok := groupsCopy[userGroup]; !ok {
groupsCopy[userGroup] = "用户分组"
}
// 如果userGroup在UserUsableGroups中返回UserUsableGroups
return groupsCopy
}
func GroupInUserUsableGroups(groupName string) bool {
_, ok := userUsableGroups[groupName]
return ok
}

View File

@@ -162,9 +162,15 @@ const ChannelsTable = () => {
return (
<div>
<Space spacing={2}>
{text?.split(',').map((item, index) => {
return renderGroup(item);
})}
{text?.split(',')
.sort((a, b) => {
if (a === 'default') return -1;
if (b === 'default') return 1;
return a.localeCompare(b);
})
.map((item, index) => {
return renderGroup(item);
})}
</Space>
</div>
);
@@ -507,6 +513,8 @@ const ChannelsTable = () => {
const [selectedChannels, setSelectedChannels] = useState([]);
const [showEditPriority, setShowEditPriority] = useState(false);
const [enableTagMode, setEnableTagMode] = useState(false);
const [showBatchSetTag, setShowBatchSetTag] = useState(false);
const [batchSetTagValue, setBatchSetTagValue] = useState('');
const removeRecord = (record) => {
@@ -968,6 +976,29 @@ const ChannelsTable = () => {
}
};
const batchSetChannelTag = async () => {
if (selectedChannels.length === 0) {
showError(t('请先选择要设置标签的渠道!'));
return;
}
if (batchSetTagValue === '') {
showError(t('标签不能为空!'));
return;
}
let ids = selectedChannels.map(channel => channel.id);
const res = await API.post('/api/channel/batch/tag', {
ids: ids,
tag: batchSetTagValue === '' ? null : batchSetTagValue
});
if (res.data.success) {
showSuccess(t('已为 ${count} 个渠道设置标签!').replace('${count}', res.data.data));
await refresh();
setShowBatchSetTag(false);
} else {
showError(res.data.message);
}
};
return (
<>
<EditTagModal
@@ -1115,11 +1146,11 @@ const ChannelsTable = () => {
</div>
<div style={{ marginTop: 20 }}>
<Space>
<Typography.Text strong>{t('开启批量删除')}</Typography.Text>
<Typography.Text strong>{t('开启批量操作')}</Typography.Text>
<Switch
label={t('开启批量删除')}
label={t('开启批量操作')}
uncheckedText={t('关')}
aria-label={t('是否开启批量删除')}
aria-label={t('是否开启批量操作')}
onChange={(v) => {
setEnableBatchDelete(v);
}}
@@ -1167,7 +1198,17 @@ const ChannelsTable = () => {
loadChannels(0, pageSize, idSort, v);
}}
/>
<Button
disabled={!enableBatchDelete}
theme="light"
type="primary"
style={{ marginRight: 8 }}
onClick={() => setShowBatchSetTag(true)}
>
{t('批量设置标签')}
</Button>
</Space>
</div>
@@ -1201,6 +1242,23 @@ const ChannelsTable = () => {
: null
}
/>
<Modal
title={t('批量设置标签')}
visible={showBatchSetTag}
onOk={batchSetChannelTag}
onCancel={() => setShowBatchSetTag(false)}
maskClosable={false}
centered={true}
>
<div style={{ marginBottom: 20 }}>
<Typography.Text>{t('请输入要设置的标签名称')}</Typography.Text>
</div>
<Input
placeholder={t('请输入标签名称')}
value={batchSetTagValue}
onChange={(v) => setBatchSetTagValue(v)}
/>
</Modal>
</>
);
};

View File

@@ -1,4 +1,4 @@
import React, { useEffect, useState } from 'react';
import React, { useContext, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
API,
@@ -25,7 +25,7 @@ import {
} from '@douyinfe/semi-ui';
import { ITEMS_PER_PAGE } from '../constants';
import {
renderAudioModelPrice,
renderAudioModelPrice, renderGroup,
renderModelPrice, renderModelPriceSimple,
renderNumber,
renderQuota,
@@ -33,6 +33,7 @@ import {
} from '../helpers/render';
import Paragraph from '@douyinfe/semi-ui/lib/es/typography/paragraph';
import { getLogOther } from '../helpers/other.js';
import { StyleContext } from '../context/Style/index.js';
const { Header } = Layout;
@@ -217,6 +218,37 @@ const LogsTable = () => {
);
},
},
{
title: t('分组'),
dataIndex: 'group',
render: (text, record, index) => {
if (record.type === 0 || record.type === 2) {
if (record.group) {
return (
<>
{renderGroup(record.group)}
</>
);
} else {
let other = JSON.parse(record.other);
if (other === null) {
return <></>;
}
if (other.group !== undefined) {
return (
<>
{renderGroup(other.group)}
</>
);
} else {
return <></>;
}
}
} else {
return <></>;
}
},
},
{
title: t('类型'),
dataIndex: 'type',
@@ -375,6 +407,7 @@ const LogsTable = () => {
},
];
const [styleState, styleDispatch] = useContext(StyleContext);
const [logs, setLogs] = useState([]);
const [expandData, setExpandData] = useState({});
const [showStat, setShowStat] = useState(false);
@@ -394,6 +427,7 @@ const LogsTable = () => {
start_timestamp: timestamp2string(getTodayStartTimestamp()),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
channel: '',
group: '',
});
const {
username,
@@ -402,6 +436,7 @@ const LogsTable = () => {
start_timestamp,
end_timestamp,
channel,
group,
} = inputs;
const [stat, setStat] = useState({
@@ -410,13 +445,13 @@ const LogsTable = () => {
});
const handleInputChange = (value, name) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
setInputs(inputs => ({ ...inputs, [name]: value }));
};
const getLogSelfStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&group=${group}`;
url = encodeURI(url);
let res = await API.get(url);
const { success, message, data } = res.data;
@@ -430,7 +465,7 @@ const LogsTable = () => {
const getLogStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}&group=${group}`;
url = encodeURI(url);
let res = await API.get(url);
const { success, message, data } = res.data;
@@ -573,9 +608,9 @@ const LogsTable = () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
if (isAdminUser) {
url = `/api/log/?p=${startIdx}&page_size=${pageSize}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
url = `/api/log/?p=${startIdx}&page_size=${pageSize}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}&group=${group}`;
} else {
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&group=${group}`;
}
url = encodeURI(url);
const res = await API.get(url);
@@ -659,10 +694,53 @@ const LogsTable = () => {
</Header>
<Form layout='horizontal' style={{ marginTop: 10 }}>
<>
<Form.Section>
<div style={{ marginBottom: 10 }}>
{
styleState.isMobile ? (
<div>
<Form.DatePicker
field='start_timestamp'
label={t('起始时间')}
style={{ width: 272 }}
initValue={start_timestamp}
type='dateTime'
onChange={(value) => {
console.log(value);
handleInputChange(value, 'start_timestamp')
}}
/>
<Form.DatePicker
field='end_timestamp'
fluid
label={t('结束时间')}
style={{ width: 272 }}
initValue={end_timestamp}
type='dateTime'
onChange={(value) => handleInputChange(value, 'end_timestamp')}
/>
</div>
) : (
<Form.DatePicker
field="range_timestamp"
label={t('时间范围')}
initValue={[start_timestamp, end_timestamp]}
type="dateTimeRange"
name="range_timestamp"
onChange={(value) => {
if (Array.isArray(value) && value.length === 2) {
handleInputChange(value[0], 'start_timestamp');
handleInputChange(value[1], 'end_timestamp');
}
}}
/>
)
}
</div>
</Form.Section>
<Form.Input
field='token_name'
label={t('令牌名称')}
style={{ width: 176 }}
value={token_name}
placeholder={t('可选值')}
name='token_name'
@@ -671,39 +749,24 @@ const LogsTable = () => {
<Form.Input
field='model_name'
label={t('模型名称')}
style={{ width: 176 }}
value={model_name}
placeholder={t('可选值')}
name='model_name'
onChange={(value) => handleInputChange(value, 'model_name')}
/>
<Form.DatePicker
field='start_timestamp'
label={t('起始时间')}
style={{ width: 272 }}
initValue={start_timestamp}
value={start_timestamp}
type='dateTime'
name='start_timestamp'
onChange={(value) => handleInputChange(value, 'start_timestamp')}
/>
<Form.DatePicker
field='end_timestamp'
fluid
label={t('结束时间')}
style={{ width: 272 }}
initValue={end_timestamp}
value={end_timestamp}
type='dateTime'
name='end_timestamp'
onChange={(value) => handleInputChange(value, 'end_timestamp')}
<Form.Input
field='group'
label={t('分组')}
value={group}
placeholder={t('可选值')}
name='group'
onChange={(value) => handleInputChange(value, 'group')}
/>
{isAdminUser && (
<>
<Form.Input
field='channel'
label={t('渠道 ID')}
style={{ width: 176 }}
value={channel}
placeholder={t('可选值')}
name='channel'
@@ -712,7 +775,6 @@ const LogsTable = () => {
<Form.Input
field='username'
label={t('用户名称')}
style={{ width: 176 }}
value={username}
placeholder={t('可选值')}
name='username'

View File

@@ -81,41 +81,24 @@ const ModelPricing = () => {
}
function renderAvailable(available) {
return available ? (
return (
<Popover
content={
<div style={{ padding: 8 }}>{t('您的分组可以使用该模型')}</div>
}
position='top'
key={available}
style={{
backgroundColor: 'rgba(var(--semi-blue-4),1)',
borderColor: 'rgba(var(--semi-blue-4),1)',
color: 'var(--semi-color-white)',
borderWidth: 1,
borderStyle: 'solid',
}}
content={
<div style={{ padding: 8 }}>{t('您的分组可以使用该模型')}</div>
}
position='top'
key={available}
style={{
backgroundColor: 'rgba(var(--semi-blue-4),1)',
borderColor: 'rgba(var(--semi-blue-4),1)',
color: 'var(--semi-color-white)',
borderWidth: 1,
borderStyle: 'solid',
}}
>
<IconVerify style={{ color: 'green' }} size="large" />
<IconVerify style={{ color: 'green' }} size="large" />
</Popover>
) : (
<Popover
content={
<div style={{ padding: 8 }}>{t('您的分组无权使用该模型')}</div>
}
position='top'
key={available}
style={{
backgroundColor: 'rgba(var(--semi-blue-4),1)',
borderColor: 'rgba(var(--semi-blue-4),1)',
color: 'var(--semi-color-white)',
borderWidth: 1,
borderStyle: 'solid',
}}
>
<IconUploadError style={{ color: '#FFA54F' }} size="large" />
</Popover>
);
)
}
const columns = [
@@ -162,36 +145,39 @@ const ModelPricing = () => {
title: t('可用分组'),
dataIndex: 'enable_groups',
render: (text, record, index) => {
// enable_groups is a string array
return (
<Space>
{text.map((group) => {
if (group === selectedGroup) {
return (
<Tag
color='blue'
size='large'
prefixIcon={<IconVerify />}
>
{group}
</Tag>
);
} else {
return (
<Tag
color='blue'
size='large'
onClick={() => {
setSelectedGroup(group);
showInfo(t('当前查看的分组为:{{group}},倍率为:{{ratio}}', {
group: group,
ratio: groupRatio[group]
}));
}}
>
{group}
</Tag>
);
if (usableGroup[group]) {
if (group === selectedGroup) {
return (
<Tag
color='blue'
size='large'
prefixIcon={<IconVerify />}
>
{group}
</Tag>
);
} else {
return (
<Tag
color='blue'
size='large'
onClick={() => {
setSelectedGroup(group);
showInfo(t('当前查看的分组为:{{group}},倍率为:{{ratio}}', {
group: group,
ratio: groupRatio[group]
}));
}}
>
{group}
</Tag>
);
}
}
})}
</Space>
@@ -275,6 +261,7 @@ const ModelPricing = () => {
const [loading, setLoading] = useState(true);
const [userState, userDispatch] = useContext(UserContext);
const [groupRatio, setGroupRatio] = useState({});
const [usableGroup, setUsableGroup] = useState({});
const setModelsFormat = (models, groupRatio) => {
for (let i = 0; i < models.length; i++) {
@@ -309,9 +296,10 @@ const ModelPricing = () => {
let url = '';
url = `/api/pricing`;
const res = await API.get(url);
const { success, message, data, group_ratio } = res.data;
const { success, message, data, group_ratio, usable_group } = res.data;
if (success) {
setGroupRatio(group_ratio);
setUsableGroup(usable_group);
setSelectedGroup(userState.user ? userState.user.group : 'default')
setModelsFormat(data, group_ratio);
} else {

View File

@@ -406,7 +406,7 @@ const UsersTable = () => {
if (searchKeyword === '') {
await loadUsers(activePage - 1);
} else {
await searchUsers();
await searchUsers(searchKeyword, searchGroup);
}
};

View File

@@ -59,6 +59,9 @@ export function renderNumber(num) {
}
export function renderQuotaNumberWithDigit(num, digits = 2) {
if (typeof num !== 'number' || isNaN(num)) {
return 0;
}
let displayInCurrency = localStorage.getItem('display_in_currency');
num = num.toFixed(digits);
if (displayInCurrency) {

View File

@@ -180,6 +180,9 @@ export function timestamp2string1(timestamp, dataExportDefaultTime = 'hour') {
let month = (date.getMonth() + 1).toString();
let day = date.getDate().toString();
let hour = date.getHours().toString();
if (day === '24') {
console.log("timestamp", timestamp);
}
if (month.length === 1) {
month = '0' + month;
}

View File

@@ -546,8 +546,8 @@
"是否用ID排序": "Whether to sort by ID",
"确定?": "Sure?",
"确定是否要删除禁用通道?": "Are you sure you want to delete the disabled channel?",
"开启批量删除": "Enable batch selection",
"是否开启批量删除": "Whether to enable batch selection",
"开启批量操作": "Enable batch selection",
"是否开启批量操作": "Whether to enable batch selection",
"确定是否要删除所选通道?": "Are you sure you want to delete the selected channels?",
"确定是否要修复数据库一致性?": "Are you sure you want to repair database consistency?",
"进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "When performing this operation, it may cause channel access errors. Please only use it when there is a problem with the database.",

View File

@@ -430,6 +430,8 @@
"一小时后过期": "Expires after one hour",
"一分钟后过期": "Expires after one minute",
"创建新的令牌": "Create New Token",
"令牌分组,默认为用户的分组": "Token group, default is the your's group",
"IP白名单请勿过度信任此功能": "IP whitelist (do not overly trust this function)",
"注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note that the quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account.",
"设为无限额度": "Set to unlimited quota",
"更新令牌信息": "Update Token Information",
@@ -546,8 +548,8 @@
"是否用ID排序": "Whether to sort by ID",
"确定?": "Sure?",
"确定是否要删除禁用通道?": "Are you sure you want to delete the disabled channel?",
"开启批量删除": "Enable batch selection",
"是否开启批量删除": "Whether to enable batch selection",
"开启批量操作": "Enable batch selection",
"是否开启批量操作": "Whether to enable batch selection",
"确定是否要删除所选通道?": "Are you sure you want to delete the selected channels?",
"确定是否要修复数据库一致性?": "Are you sure you want to repair database consistency?",
"进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "When performing this operation, it may cause channel access errors. Please only use it when there is a problem with the database.",
@@ -866,8 +868,8 @@
"请选择模式": "Please select mode",
"图片代理方式": "Picture agency method",
"用于替换 https://cdn.discordapp.com 的域名": "The domain name used to replace https://cdn.discordapp.com",
"一个月": "a month",
"一天": "one day",
"一个月": "A month",
"一天": "One day",
"令牌渠道分组选择": "Token channel grouping selection",
"只可使用对应分组包含的模型。": "Only models contained in the corresponding group can be used.",
"渠道分组": "Channel grouping",
@@ -876,7 +878,7 @@
"启用模型限制(非必要,不建议启用)": "Enable model restrictions (not necessary, not recommended)",
"秒": "Second",
"更新令牌后需等待几分钟生效": "It will take a few minutes to take effect after updating the token.",
"一小时": "one hour",
"一小时": "One hour",
"新建数量": "New quantity",
"加载失败,请稍后重试": "Loading failed, please try again later",
"未设置": "Not set",
@@ -1234,5 +1236,9 @@
"应用更改": "Apply changes",
"更多": "Expand more",
"个模型": "models",
"可用模型": "Available models"
"可用模型": "Available models",
"时间范围": "Time range",
"批量设置标签": "Batch set tag",
"请输入要设置的标签名称": "Please enter the tag name to be set",
"请输入标签名称": "Please enter the tag name"
}

View File

@@ -193,14 +193,16 @@ const EditChannel = (props) => {
const fetchUpstreamModelList = async (name) => {
if (inputs['type'] !== 1) {
showError(t('仅支持 OpenAI 接口格式'));
return;
}
// if (inputs['type'] !== 1) {
// showError(t('仅支持 OpenAI 接口格式'));
// return;
// }
setLoading(true);
const models = inputs['models'] || [];
let err = false;
if (isEdit) {
// 如果是编辑模式使用已有的channel id获取模型列表
const res = await API.get('/api/channel/fetch_models/' + channelId);
if (res.data && res.data?.success) {
models.push(...res.data.data);
@@ -208,30 +210,29 @@ const EditChannel = (props) => {
err = true;
}
} else {
// 如果是新建模式,通过后端代理获取模型列表
if (!inputs?.['key']) {
showError(t('请填写密钥'));
err = true;
} else {
try {
const host = new URL((inputs['base_url'] || 'https://api.openai.com'));
const url = `https://${host.hostname}/v1/models`;
const key = inputs['key'];
const res = await axios.get(url, {
headers: {
'Authorization': `Bearer ${key}`
}
const res = await API.post('/api/channel/fetch_models', {
base_url: inputs['base_url'],
key: inputs['key']
});
if (res.data) {
models.push(...res.data.data.map((model) => model.id));
if (res.data && res.data.success) {
models.push(...res.data.data);
} else {
err = true;
}
} catch (error) {
console.error('Error fetching models:', error);
err = true;
}
}
}
if (!err) {
handleInputChange(name, Array.from(new Set(models)));
showSuccess(t('获取模型列表成功'));
@@ -638,7 +639,7 @@ const EditChannel = (props) => {
{inputs.type === 21 && (
<>
<div style={{ marginTop: 10 }}>
<Typography.Text strong>识库 ID</Typography.Text>
<Typography.Text strong><EFBFBD><EFBFBD>识库 ID</Typography.Text>
</div>
<Input
label="知识库 ID"

View File

@@ -143,8 +143,7 @@ const Detail = (props) => {
content: [
{
key: (datum) => datum['Model'],
value: (datum) =>
renderQuotaNumberWithDigit(parseFloat(datum['Usage']), 4),
value: (datum) => renderQuota(datum['rawQuota'] || 0, 4),
},
],
},
@@ -152,22 +151,28 @@ const Detail = (props) => {
content: [
{
key: (datum) => datum['Model'],
value: (datum) => datum['Usage'],
value: (datum) => datum['rawQuota'] || 0,
},
],
updateContent: (array) => {
array.sort((a, b) => b.value - a.value);
let sum = 0;
for (let i = 0; i < array.length; i++) {
sum += parseFloat(array[i].value);
array[i].value = renderQuotaNumberWithDigit(
parseFloat(array[i].value),
4,
);
if (array[i].key == "其他") {
continue;
}
let value = parseFloat(array[i].value);
if (isNaN(value)) {
value = 0;
}
if (array[i].datum && array[i].datum.TimeSum) {
sum = array[i].datum.TimeSum;
}
array[i].value = renderQuota(value, 4);
}
array.unshift({
key: t('总计'),
value: renderQuotaNumberWithDigit(sum, 4),
value: renderQuota(sum, 4),
});
return array;
},
@@ -212,19 +217,8 @@ const Detail = (props) => {
created_at: now.getTime() / 1000,
});
}
// 根据dataExportDefaultTime重制时间粒度
let timeGranularity = 3600;
if (dataExportDefaultTime === 'day') {
timeGranularity = 86400;
} else if (dataExportDefaultTime === 'week') {
timeGranularity = 604800;
}
// sort created_at
data.sort((a, b) => a.created_at - b.created_at);
data.forEach((item) => {
item['created_at'] =
Math.floor(item['created_at'] / timeGranularity) * timeGranularity;
});
updateChartData(data);
} else {
showError(message);
@@ -250,14 +244,14 @@ const Detail = (props) => {
let uniqueModels = new Set();
let totalTokens = 0;
// 收集所有唯一的模型名称和时间点
let uniqueTimes = new Set();
// 收集所有唯一的模型名称
data.forEach(item => {
uniqueModels.add(item.model_name);
uniqueTimes.add(timestamp2string1(item.created_at, dataExportDefaultTime));
totalTokens += item.token_used;
totalQuota += item.quota;
totalTimes += item.count;
});
// 处理颜色映射
const newModelColors = {};
Array.from(uniqueModels).forEach((modelName) => {
@@ -267,56 +261,82 @@ const Detail = (props) => {
});
setModelColors(newModelColors);
// 处理饼图数据
for (let item of data) {
totalQuota += item.quota;
totalTimes += item.count;
let pieItem = newPieData.find((it) => it.type === item.model_name);
if (pieItem) {
pieItem.value += item.count;
} else {
newPieData.push({
type: item.model_name,
value: item.count,
// 按时间和模型聚合数据
let aggregatedData = new Map();
data.forEach(item => {
const timeKey = timestamp2string1(item.created_at, dataExportDefaultTime);
const modelKey = item.model_name;
const key = `${timeKey}-${modelKey}`;
if (!aggregatedData.has(key)) {
aggregatedData.set(key, {
time: timeKey,
model: modelKey,
quota: 0,
count: 0
});
}
const existing = aggregatedData.get(key);
existing.quota += item.quota;
existing.count += item.count;
});
// 处理饼图数据
let modelTotals = new Map();
for (let [_, value] of aggregatedData) {
if (!modelTotals.has(value.model)) {
modelTotals.set(value.model, 0);
}
modelTotals.set(value.model, modelTotals.get(value.model) + value.count);
}
// 处理柱状图数据
let timePoints = Array.from(uniqueTimes);
newPieData = Array.from(modelTotals).map(([model, count]) => ({
type: model,
value: count
}));
// 生成时间点序列
let timePoints = Array.from(new Set([...aggregatedData.values()].map(d => d.time)));
if (timePoints.length < 7) {
// 根据时间粒度生成合适的时间点
const generateTimePoints = () => {
let lastTime = Math.max(...data.map(item => item.created_at));
let points = [];
let interval = dataExportDefaultTime === 'hour' ? 3600
const lastTime = Math.max(...data.map(item => item.created_at));
const interval = dataExportDefaultTime === 'hour' ? 3600
: dataExportDefaultTime === 'day' ? 86400
: 604800;
for (let i = 0; i < 7; i++) {
points.push(timestamp2string1(lastTime - (i * interval), dataExportDefaultTime));
}
return points.reverse();
};
timePoints = generateTimePoints();
timePoints = Array.from({length: 7}, (_, i) =>
timestamp2string1(lastTime - (6-i) * interval, dataExportDefaultTime)
);
}
// 为每个时间点和模型生成数据
// 生成柱状图数据
timePoints.forEach(time => {
Array.from(uniqueModels).forEach(model => {
let existingData = data.find(item =>
timestamp2string1(item.created_at, dataExportDefaultTime) === time &&
item.model_name === model
);
newLineData.push({
// 为每个时间点收集所有模型的数据
let timeData = Array.from(uniqueModels).map(model => {
const key = `${time}-${model}`;
const aggregated = aggregatedData.get(key);
return {
Time: time,
Model: model,
Usage: existingData ? parseFloat(getQuotaWithUnit(existingData.quota)) : 0
});
rawQuota: aggregated?.quota || 0,
Usage: aggregated?.quota ? getQuotaWithUnit(aggregated.quota, 4) : 0
};
});
// 计算该时间点的总计
const timeSum = timeData.reduce((sum, item) => sum + item.rawQuota, 0);
// 按照 rawQuota 从大到小排序
timeData.sort((a, b) => b.rawQuota - a.rawQuota);
// 为每个数据点添加该时间的总计
timeData = timeData.map(item => ({
...item,
TimeSum: timeSum
}));
// 将排序后的数据添加到 newLineData
newLineData.push(...timeData);
});
// 排序

View File

@@ -23,6 +23,7 @@ import {
} from '@douyinfe/semi-ui';
import Title from '@douyinfe/semi-ui/lib/es/typography/title';
import { Divider } from 'semantic-ui-react';
import { useTranslation } from 'react-i18next';
const EditToken = (props) => {
const [isEdit, setIsEdit] = useState(false);
@@ -52,6 +53,7 @@ const EditToken = (props) => {
const [models, setModels] = useState([]);
const [groups, setGroups] = useState([]);
const navigate = useNavigate();
const { t } = useTranslation();
const handleInputChange = (name, value) => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
@@ -87,7 +89,7 @@ const EditToken = (props) => {
}));
setModels(localModelOptions);
} else {
showError(message);
showError(t(message));
}
};
@@ -95,15 +97,13 @@ const EditToken = (props) => {
let res = await API.get(`/api/user/self/groups`);
const { success, message, data } = res.data;
if (success) {
// return data is a map, key is group name, value is group description
// label is group description, value is group name
let localGroupOptions = Object.keys(data).map((group) => ({
label: data[group],
value: group,
}));
setGroups(localGroupOptions);
let localGroupOptions = Object.keys(data).map((group) => ({
label: data[group],
value: group,
}));
setGroups(localGroupOptions);
} else {
showError(message);
showError(t(message));
}
};
@@ -176,7 +176,7 @@ const EditToken = (props) => {
if (localInputs.expired_time !== -1) {
let time = Date.parse(localInputs.expired_time);
if (isNaN(time)) {
showError('过期时间格式错误!');
showError(t('过期时间格式错误!'));
setLoading(false);
return;
}
@@ -189,11 +189,11 @@ const EditToken = (props) => {
});
const { success, message } = res.data;
if (success) {
showSuccess('令牌更新成功!');
showSuccess(t('令牌更新成功!'));
props.refresh();
props.handleClose();
} else {
showError(message);
showError(t(message));
}
} else {
// 处理新增多个令牌的情况
@@ -209,7 +209,7 @@ const EditToken = (props) => {
if (localInputs.expired_time !== -1) {
let time = Date.parse(localInputs.expired_time);
if (isNaN(time)) {
showError('过期时间格式错误!');
showError(t('过期时间格式错误!'));
setLoading(false);
break;
}
@@ -222,14 +222,14 @@ const EditToken = (props) => {
if (success) {
successCount++;
} else {
showError(message);
showError(t(message));
break; // 如果创建失败,终止循环
}
}
if (successCount > 0) {
showSuccess(
`${successCount}令牌创建成功,请在列表页面点击复制获取令牌!`,
t('令牌创建成功,请在列表页面点击复制获取令牌!')
);
props.refresh();
props.handleClose();
@@ -245,7 +245,7 @@ const EditToken = (props) => {
<SideSheet
placement={isEdit ? 'right' : 'left'}
title={
<Title level={3}>{isEdit ? '更新令牌信息' : '创建新的令牌'}</Title>
<Title level={3}>{isEdit ? t('更新令牌信息') : t('创建新的令牌')}</Title>
}
headerStyle={{ borderBottom: '1px solid var(--semi-color-border)' }}
bodyStyle={{ borderBottom: '1px solid var(--semi-color-border)' }}
@@ -254,7 +254,7 @@ const EditToken = (props) => {
<div style={{ display: 'flex', justifyContent: 'flex-end' }}>
<Space>
<Button theme='solid' size={'large'} onClick={submit}>
提交
{t('提交')}
</Button>
<Button
theme='solid'
@@ -262,7 +262,7 @@ const EditToken = (props) => {
type={'tertiary'}
onClick={handleCancel}
>
取消
{t('取消')}
</Button>
</Space>
</div>
@@ -274,9 +274,9 @@ const EditToken = (props) => {
<Spin spinning={loading}>
<Input
style={{ marginTop: 20 }}
label='名称'
label={t('名称')}
name='name'
placeholder={'请输入名称'}
placeholder={t('请输入名称')}
onChange={(value) => handleInputChange('name', value)}
value={name}
autoComplete='new-password'
@@ -284,9 +284,9 @@ const EditToken = (props) => {
/>
<Divider />
<DatePicker
label='过期时间'
label={t('过期时间')}
name='expired_time'
placeholder={'请选择过期时间'}
placeholder={t('请选择过期时间')}
onChange={(value) => handleInputChange('expired_time', value)}
value={expired_time}
autoComplete='new-password'
@@ -300,7 +300,7 @@ const EditToken = (props) => {
setExpiredTime(0, 0, 0, 0);
}}
>
永不过期
{t('永不过期')}
</Button>
<Button
type={'tertiary'}
@@ -308,7 +308,7 @@ const EditToken = (props) => {
setExpiredTime(0, 0, 1, 0);
}}
>
一小时
{t('一小时')}
</Button>
<Button
type={'tertiary'}
@@ -316,7 +316,7 @@ const EditToken = (props) => {
setExpiredTime(1, 0, 0, 0);
}}
>
一个月
{t('一个月')}
</Button>
<Button
type={'tertiary'}
@@ -324,7 +324,7 @@ const EditToken = (props) => {
setExpiredTime(0, 1, 0, 0);
}}
>
一天
{t('一天')}
</Button>
</Space>
</div>
@@ -332,17 +332,15 @@ const EditToken = (props) => {
<Divider />
<Banner
type={'warning'}
description={
'注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。'
}
description={t('注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。')}
></Banner>
<div style={{ marginTop: 20 }}>
<Typography.Text>{`额度${renderQuotaWithPrompt(remain_quota)}`}</Typography.Text>
<Typography.Text>{`${t('额度')}${renderQuotaWithPrompt(remain_quota)}`}</Typography.Text>
</div>
<AutoComplete
style={{ marginTop: 8 }}
name='remain_quota'
placeholder={'请输入额度'}
placeholder={t('请输入额度')}
onChange={(value) => handleInputChange('remain_quota', value)}
value={remain_quota}
autoComplete='new-password'
@@ -362,22 +360,22 @@ const EditToken = (props) => {
{!isEdit && (
<>
<div style={{ marginTop: 20 }}>
<Typography.Text>新建数量</Typography.Text>
<Typography.Text>{t('新建数量')}</Typography.Text>
</div>
<AutoComplete
style={{ marginTop: 8 }}
label='数量'
placeholder={'请选择或输入创建令牌的数量'}
label={t('数量')}
placeholder={t('请选择或输入创建令牌的数量')}
onChange={(value) => handleTokenCountChange(value)}
onSelect={(value) => handleTokenCountChange(value)}
value={tokenCount.toString()}
autoComplete='off'
type='number'
data={[
{ value: 10, label: '10个' },
{ value: 20, label: '20个' },
{ value: 30, label: '30个' },
{ value: 100, label: '100个' },
{ value: 10, label: t('10个') },
{ value: 20, label: t('20个') },
{ value: 30, label: t('30个') },
{ value: 100, label: t('100个') },
]}
disabled={unlimited_quota}
/>
@@ -392,17 +390,17 @@ const EditToken = (props) => {
setUnlimitedQuota();
}}
>
{unlimited_quota ? '取消无限额度' : '设为无限额度'}
{unlimited_quota ? t('取消无限额度') : t('设为无限额度')}
</Button>
</div>
<Divider />
<div style={{ marginTop: 10 }}>
<Typography.Text>IP白名单请勿过度信任此功能</Typography.Text>
<Typography.Text>{t('IP白名单请勿过度信任此功能)')}</Typography.Text>
</div>
<TextArea
label='IP白名单'
label={t('IP白名单')}
name='allow_ips'
placeholder={'允许的IP一行一个'}
placeholder={t('允许的IP一行一个')}
onChange={(value) => {
handleInputChange('allow_ips', value);
}}
@@ -417,16 +415,15 @@ const EditToken = (props) => {
onChange={(e) =>
handleInputChange('model_limits_enabled', e.target.checked)
}
></Checkbox>
<Typography.Text>
启用模型限制非必要不建议启用
</Typography.Text>
>
{t('启用模型限制(非必要,不建议启用)')}
</Checkbox>
</Space>
</div>
<Select
style={{ marginTop: 8 }}
placeholder={'请选择该渠道所支持的模型'}
placeholder={t('请选择该渠道所支持的模型')}
name='models'
required
multiple
@@ -440,12 +437,12 @@ const EditToken = (props) => {
disabled={!model_limits_enabled}
/>
<div style={{ marginTop: 10 }}>
<Typography.Text>令牌分组默认为用户的分组</Typography.Text>
<Typography.Text>{t('令牌分组,默认为用户的分组')}</Typography.Text>
</div>
{groups.length > 0 ?
<Select
style={{ marginTop: 8 }}
placeholder={'令牌分组,默认为用户的分组'}
placeholder={t('令牌分组,默认为用户的分组')}
name='gruop'
required
selection
@@ -458,7 +455,7 @@ const EditToken = (props) => {
/>:
<Select
style={{ marginTop: 8 }}
placeholder={'管理员未设置用户可选分组'}
placeholder={t('管理员未设置用户可选分组')}
name='gruop'
disabled={true}
/>