Compare commits

...

22 Commits

Author SHA1 Message Date
Calcium-Ion
3f19f18dc9 Merge pull request #2278 from seefs001/fix/release-version
fix: release workflow show version
2025-11-23 23:51:32 +08:00
Calcium-Ion
a465597e78 Merge pull request #2277 from seefs001/feature/model_list_fetch
feat: 二次确认添加重定向前模型 && 重定向后模式视为已有模型
2025-11-23 23:51:11 +08:00
Calcium-Ion
dbfcb441f7 Merge pull request #2276 from seefs001/feature/internal_params
feat: embedding param override && internal params
2025-11-23 23:51:00 +08:00
Calcium-Ion
3fb2ba318d Merge pull request #2274 from seefs001/feature/thinking_level
feat: gemini thinking_level && snake params
2025-11-23 23:50:50 +08:00
CaIon
8f039b3a53 feat: Set ContextKeyLocalCountTokens in NativeGeminiEmbeddingHandler for token tracking 2025-11-23 23:50:04 +08:00
CaIon
c939686509 refactor: Deprecate HARM_CATEGORY_CIVIC_INTEGRITY in safety settings 2025-11-23 23:45:48 +08:00
Seefs
07aff1fe02 Merge pull request #1706 from StageDog/feat/discord_oauth
feat: 关联 discord 账号
2025-11-23 18:54:55 +08:00
StageDog
5f27edcd19 fix: IsDiscordIdAlreadyTaken 应该检查软删除记录 2025-11-23 00:07:34 +08:00
Seefs
f47d473e63 fix: release workflow show version 2025-11-22 20:06:13 +08:00
Seefs
7a2bd38700 feat: 重定向后的模型视为已有的模型,附带特殊提示 2025-11-22 19:34:36 +08:00
Seefs
f8c40ecca6 feat: 二次确认添加重定向前模型 2025-11-22 19:23:27 +08:00
StageDog
2bc991685f feat: 针对 discord 登录配置使用新版设置方案 2025-11-22 19:06:53 +08:00
StageDog
87811a0493 feat: 关联 discord 账号 2025-11-22 18:38:24 +08:00
Seefs
0885597427 feat: embedding param override && internal params 2025-11-22 18:27:17 +08:00
CaIon
0952973887 feat: Add CountToken configuration and update token counting logic 2025-11-22 17:15:34 +08:00
Seefs
6b30f042fa feat: gemini thinking_level && snake params 2025-11-22 16:30:46 +08:00
CaIon
efb8f1f5b8 fix: Update GET_MEDIA_TOKEN_NOT_STREAM default value to false 2025-11-22 16:23:37 +08:00
Seefs
de3cf9893d Merge pull request #2268 from chokiproai/main
feat: Add Vietnamese language support
2025-11-22 00:47:32 +08:00
Seefs
fe02e9a066 Merge pull request #2224 from jarvis-u/main
fix: 错误解析responses api中的input字段
2025-11-22 00:31:24 +08:00
CaIon
84745d5ca4 feat: Add ContextKeyLocalCountTokens and update ResponseText2Usage to use context in multiple channels 2025-11-21 18:17:01 +08:00
Chokiproai
cdb1c06ad2 add Vietnamese language support 2025-11-21 10:40:14 +07:00
wujiacheng
d9b5748f80 fix: 错误解析responses api中的input字段 2025-11-14 09:58:39 +08:00
57 changed files with 3729 additions and 258 deletions

View File

@@ -63,7 +63,7 @@
# 是否统计图片token
# GET_MEDIA_TOKEN=true
# 是否在非流stream=false情况下统计图片token
# GET_MEDIA_TOKEN_NOT_STREAM=true
# GET_MEDIA_TOKEN_NOT_STREAM=false
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
# DIFY_DEBUG=true

View File

@@ -22,6 +22,10 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Determine Version
run: |
VERSION=$(git describe --tags)
echo "VERSION=$VERSION" >> $GITHUB_ENV
- uses: oven-sh/setup-bun@v2
with:
bun-version: latest
@@ -31,7 +35,7 @@ jobs:
run: |
cd web
bun install
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
cd ..
- name: Set up Go
uses: actions/setup-go@v3
@@ -40,13 +44,11 @@ jobs:
- name: Build Backend (amd64)
run: |
go mod download
VERSION=$(git describe --tags)
go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION' -extldflags '-static'" -o new-api-$VERSION
- name: Build Backend (arm64)
run: |
sudo apt-get update
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
VERSION=$(git describe --tags)
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION' -extldflags '-static'" -o new-api-arm64-$VERSION
- name: Release
uses: softprops/action-gh-release@v2
@@ -65,6 +67,10 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Determine Version
run: |
VERSION=$(git describe --tags)
echo "VERSION=$VERSION" >> $GITHUB_ENV
- uses: oven-sh/setup-bun@v2
with:
bun-version: latest
@@ -75,7 +81,7 @@ jobs:
run: |
cd web
bun install
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
cd ..
- name: Set up Go
uses: actions/setup-go@v3
@@ -84,7 +90,6 @@ jobs:
- name: Build Backend
run: |
go mod download
VERSION=$(git describe --tags)
go build -ldflags "-X 'new-api/common.Version=$VERSION'" -o new-api-macos-$VERSION
- name: Release
uses: softprops/action-gh-release@v2
@@ -105,6 +110,10 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Determine Version
run: |
VERSION=$(git describe --tags)
echo "VERSION=$VERSION" >> $GITHUB_ENV
- uses: oven-sh/setup-bun@v2
with:
bun-version: latest
@@ -114,7 +123,7 @@ jobs:
run: |
cd web
bun install
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
cd ..
- name: Set up Go
uses: actions/setup-go@v3
@@ -123,7 +132,6 @@ jobs:
- name: Build Backend
run: |
go mod download
VERSION=$(git describe --tags)
go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION'" -o new-api-$VERSION.exe
- name: Release
uses: softprops/action-gh-release@v2
@@ -132,5 +140,3 @@ jobs:
files: new-api-*.exe
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -193,6 +193,7 @@ docker run --name new-api -d --restart always \
### 🔐 Authorization and Security
- 😈 Discord authorization login
- 🤖 LinuxDO authorization login
- 📱 Telegram authorization login
- 🔑 OIDC unified authentication

View File

@@ -193,6 +193,7 @@ docker run --name new-api -d --restart always \
### 🔐 授权与安全
- 😈 Discord 授权登录
- 🤖 LinuxDO 授权登录
- 📱 Telegram 授权登录
- 🔑 OIDC 统一认证

View File

@@ -30,6 +30,11 @@ func printHelp() {
func InitEnv() {
flag.Parse()
envVersion := os.Getenv("VERSION")
if envVersion != "" {
Version = envVersion
}
if *PrintVersion {
fmt.Println(Version)
os.Exit(0)
@@ -111,8 +116,9 @@ func initConstantEnv() {
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
// ForceStreamOption 覆盖请求参数强制返回usage信息
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", false)
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)

View File

@@ -46,5 +46,7 @@ const (
ContextKeyUsingGroup ContextKey = "group"
ContextKeyUserName ContextKey = "username"
ContextKeyLocalCountTokens ContextKey = "local_count_tokens"
ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
)

View File

@@ -4,6 +4,7 @@ var StreamingTimeout int
var DifyDebug bool
var MaxFileDownloadMB int
var ForceStreamOption bool
var CountToken bool
var GetMediaToken bool
var GetMediaTokenNotStream bool
var UpdateTask bool

223
controller/discord.go Normal file
View File

@@ -0,0 +1,223 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/setting/system_setting"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type DiscordResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type DiscordUser struct {
UID string `json:"id"`
ID string `json:"username"`
Name string `json:"global_name"`
}
func getDiscordUserInfoByCode(code string) (*DiscordUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := url.Values{}
values.Set("client_id", system_setting.GetDiscordSettings().ClientId)
values.Set("client_secret", system_setting.GetDiscordSettings().ClientSecret)
values.Set("code", code)
values.Set("grant_type", "authorization_code")
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress))
formData := values.Encode()
req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
}
defer res.Body.Close()
var discordResponse DiscordResponse
err = json.NewDecoder(res.Body).Decode(&discordResponse)
if err != nil {
return nil, err
}
if discordResponse.AccessToken == "" {
common.SysError("Discord 获取 Token 失败,请检查设置!")
return nil, errors.New("Discord 获取 Token 失败,请检查设置!")
}
req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
common.SysLog(err.Error())
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
}
defer res2.Body.Close()
if res2.StatusCode != http.StatusOK {
common.SysError("Discord 获取用户信息失败!请检查设置!")
return nil, errors.New("Discord 获取用户信息失败!请检查设置!")
}
var discordUser DiscordUser
err = json.NewDecoder(res2.Body).Decode(&discordUser)
if err != nil {
return nil, err
}
if discordUser.UID == "" || discordUser.ID == "" {
common.SysError("Discord 获取用户信息为空!请检查设置!")
return nil, errors.New("Discord 获取用户信息为空!请检查设置!")
}
return &discordUser, nil
}
func DiscordOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
DiscordBind(c)
return
}
if !system_setting.GetDiscordSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Discord 登录以及注册",
})
return
}
code := c.Query("code")
discordUser, err := getDiscordUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
DiscordId: discordUser.UID,
}
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
err := user.FillUserByDiscordId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
if discordUser.ID != "" {
user.Username = discordUser.ID
} else {
user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if discordUser.Name != "" {
user.DisplayName = discordUser.Name
} else {
user.DisplayName = "Discord User"
}
err := user.Insert(0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func DiscordBind(c *gin.Context) {
if !system_setting.GetDiscordSettings().Enabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Discord 登录以及注册",
})
return
}
code := c.Query("code")
discordUser, err := getDiscordUserInfoByCode(code)
if err != nil {
common.ApiError(c, err)
return
}
user := model.User{
DiscordId: discordUser.UID,
}
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 Discord 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
common.ApiError(c, err)
return
}
user.DiscordId = discordUser.UID
err = user.Update(false)
if err != nil {
common.ApiError(c, err)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
}

View File

@@ -52,6 +52,8 @@ func GetStatus(c *gin.Context) {
"email_verification": common.EmailVerificationEnabled,
"github_oauth": common.GitHubOAuthEnabled,
"github_client_id": common.GitHubClientId,
"discord_oauth": system_setting.GetDiscordSettings().Enabled,
"discord_client_id": system_setting.GetDiscordSettings().ClientId,
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
"linuxdo_client_id": common.LinuxDOClientId,
"linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,

View File

@@ -71,6 +71,14 @@ func UpdateOption(c *gin.Context) {
})
return
}
case "discord.enabled":
if option.Value == "true" && system_setting.GetDiscordSettings().ClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 Discord OAuth请先填入 Discord Client Id 以及 Discord Client Secret",
})
return
}
case "oidc.enabled":
if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" {
c.JSON(http.StatusOK, gin.H{

View File

@@ -453,6 +453,7 @@ func GetSelf(c *gin.Context) {
"status": user.Status,
"email": user.Email,
"github_id": user.GitHubId,
"discord_id": user.DiscordId,
"oidc_id": user.OidcId,
"wechat_id": user.WeChatId,
"telegram_id": user.TelegramId,

View File

@@ -42,6 +42,7 @@
| 方法 | 路径 | 鉴权 | 说明 |
|------|------|------|------|
| GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 |
| GET | /api/oauth/discord | 公开 | Discord 通用 OAuth 跳转 |
| GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 |
| GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 |
| GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 |

View File

@@ -142,7 +142,38 @@ type GeminiThinkingConfig struct {
IncludeThoughts bool `json:"includeThoughts,omitempty"`
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
// TODO Conflict with thinkingbudget.
// ThinkingLevel json.RawMessage `json:"thinkingLevel,omitempty"`
ThinkingLevel json.RawMessage `json:"thinkingLevel,omitempty"`
}
// UnmarshalJSON allows GeminiThinkingConfig to accept both snake_case and camelCase fields.
func (c *GeminiThinkingConfig) UnmarshalJSON(data []byte) error {
type Alias GeminiThinkingConfig
var aux struct {
Alias
IncludeThoughtsSnake *bool `json:"include_thoughts,omitempty"`
ThinkingBudgetSnake *int `json:"thinking_budget,omitempty"`
ThinkingLevelSnake json.RawMessage `json:"thinking_level,omitempty"`
}
if err := common.Unmarshal(data, &aux); err != nil {
return err
}
*c = GeminiThinkingConfig(aux.Alias)
if aux.IncludeThoughtsSnake != nil {
c.IncludeThoughts = *aux.IncludeThoughtsSnake
}
if aux.ThinkingBudgetSnake != nil {
c.ThinkingBudget = aux.ThinkingBudgetSnake
}
if len(aux.ThinkingLevelSnake) > 0 {
c.ThinkingLevel = aux.ThinkingLevelSnake
}
return nil
}
func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {

View File

@@ -897,6 +897,12 @@ type Reasoning struct {
Summary string `json:"summary,omitempty"`
}
type Input struct {
Type string `json:"type,omitempty"`
Role string `json:"role,omitempty"`
Content json.RawMessage `json:"content,omitempty"`
}
type MediaInput struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
@@ -915,7 +921,7 @@ func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
return nil
}
var inputs []MediaInput
var mediaInputs []MediaInput
// Try string first
// if str, ok := common.GetJsonType(r.Input); ok {
@@ -925,60 +931,74 @@ func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
if common.GetJsonType(r.Input) == "string" {
var str string
_ = common.Unmarshal(r.Input, &str)
inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
return inputs
mediaInputs = append(mediaInputs, MediaInput{Type: "input_text", Text: str})
return mediaInputs
}
// Try array of parts
if common.GetJsonType(r.Input) == "array" {
var array []any
_ = common.Unmarshal(r.Input, &array)
for _, itemAny := range array {
// Already parsed MediaInput
if media, ok := itemAny.(MediaInput); ok {
inputs = append(inputs, media)
continue
var inputs []Input
_ = common.Unmarshal(r.Input, &inputs)
for _, input := range inputs {
if common.GetJsonType(input.Content) == "string" {
var str string
_ = common.Unmarshal(input.Content, &str)
mediaInputs = append(mediaInputs, MediaInput{Type: "input_text", Text: str})
}
// Generic map
item, ok := itemAny.(map[string]any)
if !ok {
continue
}
typeVal, ok := item["type"].(string)
if !ok {
continue
}
switch typeVal {
case "input_text":
text, _ := item["text"].(string)
inputs = append(inputs, MediaInput{Type: "input_text", Text: text})
case "input_image":
// image_url may be string or object with url field
var imageUrl string
switch v := item["image_url"].(type) {
case string:
imageUrl = v
case map[string]any:
if url, ok := v["url"].(string); ok {
imageUrl = url
if common.GetJsonType(input.Content) == "array" {
var array []any
_ = common.Unmarshal(input.Content, &array)
for _, itemAny := range array {
// Already parsed MediaContent
if media, ok := itemAny.(MediaInput); ok {
mediaInputs = append(mediaInputs, media)
continue
}
// Generic map
item, ok := itemAny.(map[string]any)
if !ok {
continue
}
typeVal, ok := item["type"].(string)
if !ok {
continue
}
switch typeVal {
case "input_text":
text, _ := item["text"].(string)
mediaInputs = append(mediaInputs, MediaInput{Type: "input_text", Text: text})
case "input_image":
// image_url may be string or object with url field
var imageUrl string
switch v := item["image_url"].(type) {
case string:
imageUrl = v
case map[string]any:
if url, ok := v["url"].(string); ok {
imageUrl = url
}
}
mediaInputs = append(mediaInputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
case "input_file":
// file_url may be string or object with url field
var fileUrl string
switch v := item["file_url"].(type) {
case string:
fileUrl = v
case map[string]any:
if url, ok := v["url"].(string); ok {
fileUrl = url
}
}
mediaInputs = append(mediaInputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
}
}
inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
case "input_file":
// file_url may be string or object with url field
var fileUrl string
switch v := item["file_url"].(type) {
case string:
fileUrl = v
case map[string]any:
if url, ok := v["url"].(string); ok {
fileUrl = url
}
}
inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
}
}
}
return inputs
return mediaInputs
}

View File

@@ -27,6 +27,7 @@ type User struct {
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
DiscordId string `json:"discord_id" gorm:"column:discord_id;index"`
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
@@ -539,6 +540,14 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
func (user *User) FillUserByDiscordId() error {
if user.DiscordId == "" {
return errors.New("discord id 为空!")
}
DB.Where(User{DiscordId: user.DiscordId}).First(user)
return nil
}
func (user *User) FillUserByOidcId() error {
if user.OidcId == "" {
return errors.New("oidc id 为空!")
@@ -578,6 +587,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsDiscordIdAlreadyTaken(discordId string) bool {
return DB.Unscoped().Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1
}
func IsOidcIdAlreadyTaken(oidcId string) bool {
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
}

View File

@@ -673,7 +673,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
if requestMode == RequestModeCompletion {
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
@@ -682,7 +682,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
if common.DebugEnabled {
common.SysLog("claude response usage is not complete, maybe upstream error")
}
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
}

View File

@@ -74,7 +74,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
if err := scanner.Err(); err != nil {
logger.LogError(c, "error_scanning_stream_response: "+err.Error())
}
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
@@ -105,7 +105,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response)
for _, choice := range response.Choices {
responseText += choice.Message.StringContent()
}
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage
response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response)

View File

@@ -165,7 +165,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
}
})
if usage.PromptTokens == 0 {
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
}
return usage, nil
}

View File

@@ -142,7 +142,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
helper.Done(c)
if usage.TotalTokens == 0 {
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
}
return usage, nil

View File

@@ -246,7 +246,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
})
helper.Done(c)
if usage.TotalTokens == 0 {
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
}
usage.CompletionTokens += nodeToken
return usage, nil

View File

@@ -32,7 +32,7 @@ var SafetySettingList = []string{
"HARM_CATEGORY_HATE_SPEECH",
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
"HARM_CATEGORY_DANGEROUS_CONTENT",
"HARM_CATEGORY_CIVIC_INTEGRITY",
//"HARM_CATEGORY_CIVIC_INTEGRITY", This item is deprecated!
}
var ChannelName = "google gemini"

View File

@@ -3,9 +3,9 @@ package gemini
import (
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
relaycommon "github.com/QuantumNous/new-api/relay/common"
@@ -13,8 +13,6 @@ import (
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/types"
"github.com/pkg/errors"
"github.com/gin-gonic/gin"
)
@@ -77,6 +75,8 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
TotalTokens: info.PromptTokens,
}
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
if info.IsGeminiBatchEmbedding {
var geminiResponse dto.GeminiBatchEmbeddingResponse
err = common.Unmarshal(responseBody, &geminiResponse)
@@ -97,80 +97,15 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
}
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var usage = &dto.Usage{}
var imageCount int
helper.SetEventStreamHeaders(c)
responseText := strings.Builder{}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse dto.GeminiChatResponse
err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil {
logger.LogError(c, "error unmarshalling stream response: "+err.Error())
return false
}
// 统计图片数量
for _, candidate := range geminiResponse.Candidates {
for _, part := range candidate.Content.Parts {
if part.InlineData != nil && part.InlineData.MimeType != "" {
imageCount++
}
if part.Text != "" {
responseText.WriteString(part.Text)
}
}
}
// 更新使用量统计
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
if detail.Modality == "AUDIO" {
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
} else if detail.Modality == "TEXT" {
usage.PromptTokensDetails.TextTokens = detail.TokenCount
}
}
}
return geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
// 直接发送 GeminiChatResponse 响应
err = helper.StringData(c, data)
err := helper.StringData(c, data)
if err != nil {
logger.LogError(c, err.Error())
}
info.SendResponseCount++
return true
})
if info.SendResponseCount == 0 {
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
}
if imageCount != 0 {
if usage.CompletionTokens == 0 {
usage.CompletionTokens = imageCount * 258
}
}
// 如果usage.CompletionTokens为0则使用本地统计的completion tokens
if usage.CompletionTokens == 0 {
str := responseText.String()
if len(str) > 0 {
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
// 空补全,不需要使用量
usage = &dto.Usage{}
}
}
// 移除流式响应结尾的[Done]因为Gemini API没有发送Done的行为
//helper.Done(c)
return usage, nil
}

View File

@@ -954,14 +954,10 @@ func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.Ch
return nil
}
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
// responseText := ""
id := helper.GetResponseID(c)
createAt := common.GetTimestamp()
responseText := strings.Builder{}
func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response, callback func(data string, geminiResponse *dto.GeminiChatResponse) bool) (*dto.Usage, *types.NewAPIError) {
var usage = &dto.Usage{}
var imageCount int
finishReason := constant.FinishReasonStop
responseText := strings.Builder{}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse dto.GeminiChatResponse
@@ -971,6 +967,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
return false
}
// 统计图片数量
for _, candidate := range geminiResponse.Candidates {
for _, part := range candidate.Content.Parts {
if part.InlineData != nil && part.InlineData.MimeType != "" {
@@ -982,14 +979,10 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
}
}
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
response.Id = id
response.Created = createAt
response.Model = info.UpstreamModelName
// 更新使用量统计
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
@@ -1000,6 +993,45 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
}
}
}
return callback(data, &geminiResponse)
})
if imageCount != 0 {
if usage.CompletionTokens == 0 {
usage.CompletionTokens = imageCount * 1400
}
}
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
if usage.TotalTokens > 0 {
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
}
if usage.CompletionTokens <= 0 {
str := responseText.String()
if len(str) > 0 {
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
usage = &dto.Usage{}
}
}
return usage, nil
}
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
id := helper.GetResponseID(c)
createAt := common.GetTimestamp()
finishReason := constant.FinishReasonStop
usage, err := geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
response, isStop := streamResponseGeminiChat2OpenAI(geminiResponse)
response.Id = id
response.Created = createAt
response.Model = info.UpstreamModelName
logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
if info.SendResponseCount == 0 {
// send first response
@@ -1015,7 +1047,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
emptyResponse.Choices[0].Delta.ToolCalls = copiedToolCalls
}
finishReason = constant.FinishReasonToolCalls
err = handleStream(c, info, emptyResponse)
err := handleStream(c, info, emptyResponse)
if err != nil {
logger.LogError(c, err.Error())
}
@@ -1025,14 +1057,14 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
response.Choices[0].FinishReason = nil
}
} else {
err = handleStream(c, info, emptyResponse)
err := handleStream(c, info, emptyResponse)
if err != nil {
logger.LogError(c, err.Error())
}
}
}
err = handleStream(c, info, response)
err := handleStream(c, info, response)
if err != nil {
logger.LogError(c, err.Error())
}
@@ -1042,40 +1074,15 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
return true
})
if info.SendResponseCount == 0 {
// 空补全,报错不计费
// empty response, throw an error
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
}
if imageCount != 0 {
if usage.CompletionTokens == 0 {
usage.CompletionTokens = imageCount * 258
}
}
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
if usage.CompletionTokens == 0 {
str := responseText.String()
if len(str) > 0 {
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
// 空补全,不需要使用量
usage = &dto.Usage{}
}
if err != nil {
return usage, err
}
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := handleFinalStream(c, info, response)
if err != nil {
common.SysLog("send final response failed: " + err.Error())
handleErr := handleFinalStream(c, info, response)
if handleErr != nil {
common.SysLog("send final response failed: " + handleErr.Error())
}
//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
// helper.Done(c)
//}
//resp.Body.Close()
return usage, nil
}

View File

@@ -183,7 +183,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
if !containStreamUsage {
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}

View File

@@ -81,7 +81,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream {
var responseText string
err, responseText = palmStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
} else {
usage, err = palmHandler(c, info, resp)
}

View File

@@ -130,7 +130,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
service.CloseResponseBodyGracefully(resp)
return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
return service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens), nil
}
func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {

View File

@@ -70,7 +70,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
})
if !containStreamUsage {
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
}

View File

@@ -123,7 +123,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}

View File

@@ -1,12 +1,12 @@
package common
import (
"encoding/json"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -30,7 +30,7 @@ type ParamOperation struct {
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
}
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
if len(paramOverride) == 0 {
return jsonData, nil
}
@@ -38,7 +38,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) (
// 尝试断言为操作格式
if operations, ok := tryParseOperations(paramOverride); ok {
// 使用新方法
result, err := applyOperations(string(jsonData), operations)
result, err := applyOperations(string(jsonData), operations, conditionContext)
return []byte(result), err
}
@@ -123,13 +123,13 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
return nil, false
}
func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) {
func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
if len(conditions) == 0 {
return true, nil // 没有条件,直接通过
}
results := make([]bool, len(conditions))
for i, condition := range conditions {
result, err := checkSingleCondition(jsonStr, condition)
result, err := checkSingleCondition(jsonStr, contextJSON, condition)
if err != nil {
return false, err
}
@@ -153,10 +153,13 @@ func checkConditions(jsonStr string, conditions []ConditionOperation, logic stri
}
}
func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
// 处理负数索引
path := processNegativeIndex(jsonStr, condition.Path)
value := gjson.Get(jsonStr, path)
if !value.Exists() && contextJSON != "" {
value = gjson.Get(contextJSON, condition.Path)
}
if !value.Exists() {
if condition.PassMissingKey {
return true, nil
@@ -165,7 +168,7 @@ func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, e
}
// 利用gjson的类型解析
targetBytes, err := json.Marshal(condition.Value)
targetBytes, err := common.Marshal(condition.Value)
if err != nil {
return false, fmt.Errorf("failed to marshal condition value: %v", err)
}
@@ -292,7 +295,7 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool,
// applyOperationsLegacy 原参数覆盖方法
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
reqMap := make(map[string]interface{})
err := json.Unmarshal(jsonData, &reqMap)
err := common.Unmarshal(jsonData, &reqMap)
if err != nil {
return nil, err
}
@@ -301,14 +304,23 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}
reqMap[key] = value
}
return json.Marshal(reqMap)
return common.Marshal(reqMap)
}
func applyOperations(jsonStr string, operations []ParamOperation) (string, error) {
func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
var contextJSON string
if conditionContext != nil && len(conditionContext) > 0 {
ctxBytes, err := common.Marshal(conditionContext)
if err != nil {
return "", fmt.Errorf("failed to marshal condition context: %v", err)
}
contextJSON = string(ctxBytes)
}
result := jsonStr
for _, op := range operations {
// 检查条件是否满足
ok, err := checkConditions(result, op.Conditions, op.Logic)
ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
if err != nil {
return "", err
}
@@ -414,7 +426,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
var currentMap, newMap map[string]interface{}
// 解析当前值
if err := json.Unmarshal([]byte(current.Raw), &currentMap); err != nil {
if err := common.Unmarshal([]byte(current.Raw), &currentMap); err != nil {
return "", err
}
// 解析新值
@@ -422,8 +434,8 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
case map[string]interface{}:
newMap = v
default:
jsonBytes, _ := json.Marshal(v)
if err := json.Unmarshal(jsonBytes, &newMap); err != nil {
jsonBytes, _ := common.Marshal(v)
if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
return "", err
}
}
@@ -439,3 +451,31 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
}
return sjson.Set(jsonStr, path, result)
}
// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
// 目前内置以下字段:
// - model优先使用上游模型名UpstreamModelName若不存在则回落到原始模型名OriginModelName
// - upstream_model始终为通道映射后的上游模型名。
// - original_model请求最初指定的模型名。
func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
if info == nil || info.ChannelMeta == nil {
return nil
}
ctx := make(map[string]interface{})
if info.UpstreamModelName != "" {
ctx["model"] = info.UpstreamModelName
ctx["upstream_model"] = info.UpstreamModelName
}
if info.OriginModelName != "" {
ctx["original_model"] = info.OriginModelName
if _, exists := ctx["model"]; !exists {
ctx["model"] = info.OriginModelName
}
}
if len(ctx) == 0 {
return nil
}
return ctx
}

View File

@@ -144,7 +144,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}

View File

@@ -49,6 +49,14 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
}
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}
}
logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData)))
requestBody := bytes.NewBuffer(jsonData)
statusCodeMappingStr := c.GetString("status_code_mapping")

View File

@@ -156,7 +156,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}

View File

@@ -69,7 +69,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}

View File

@@ -60,7 +60,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}

View File

@@ -66,7 +66,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
// apply param override
if len(info.ParamOverride) > 0 {
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
if err != nil {
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
}

View File

@@ -30,6 +30,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth)
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth)
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/image/webp"
)
// return image.Config, format, clean base64 string, error
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
// 去除base64数据的URL前缀如果有
if idx := strings.Index(base64String, ","); idx != -1 {

View File

@@ -62,6 +62,12 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
adminInfo["is_multi_key"] = true
adminInfo["multi_key_index"] = common.GetContextKeyInt(ctx, constant.ContextKeyChannelMultiKeyIndex)
}
isLocalCountTokens := common.GetContextKeyBool(ctx, constant.ContextKeyLocalCountTokens)
if isLocalCountTokens {
adminInfo["local_count_tokens"] = isLocalCountTokens
}
other["admin_info"] = adminInfo
appendRequestPath(ctx, relayInfo, other)
return other

View File

@@ -143,6 +143,12 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
if fileMeta.Detail == "low" && !isPatchBased {
return baseTokens, nil
}
// Whether to count image tokens at all
if !constant.GetMediaToken {
return 3 * baseTokens, nil
}
if !constant.GetMediaTokenNotStream && !stream {
return 3 * baseTokens, nil
}
@@ -150,10 +156,6 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
fileMeta.Detail = "high"
}
// Whether to count image tokens at all
if !constant.GetMediaToken {
return 3 * baseTokens, nil
}
// Decode image to get dimensions
var config image.Config
@@ -256,16 +258,15 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
}
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
// 是否统计token
if !constant.CountToken {
return 0, nil
}
if meta == nil {
return 0, errors.New("token count meta is nil")
}
if !constant.GetMediaToken {
return 0, nil
}
if !constant.GetMediaTokenNotStream && !info.IsStream {
return 0, nil
}
if info.RelayFormat == types.RelayFormatOpenAIRealtime {
return 0, nil
}
@@ -316,9 +317,19 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
shouldFetchFiles = false
}
if shouldFetchFiles {
for _, file := range meta.Files {
if strings.HasPrefix(file.OriginData, "http") {
// 是否本地计算媒体token数量
if !constant.GetMediaToken {
shouldFetchFiles = false
}
// 是否在非流模式下本地计算媒体token数量
if !constant.GetMediaTokenNotStream && !info.IsStream {
shouldFetchFiles = false
}
for _, file := range meta.Files {
if strings.HasPrefix(file.OriginData, "http") {
if shouldFetchFiles {
mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
if err != nil {
return 0, fmt.Errorf("error getting file base64 from url: %v", err)
@@ -333,28 +344,28 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
file.FileType = types.FileTypeFile
}
file.MimeType = mineType
} else if strings.HasPrefix(file.OriginData, "data:") {
// get mime type from base64 header
parts := strings.SplitN(file.OriginData, ",", 2)
if len(parts) >= 1 {
header := parts[0]
// Extract mime type from "data:mime/type;base64" format
if strings.Contains(header, ":") && strings.Contains(header, ";") {
mimeStart := strings.Index(header, ":") + 1
mimeEnd := strings.Index(header, ";")
if mimeStart < mimeEnd {
mineType := header[mimeStart:mimeEnd]
if strings.HasPrefix(mineType, "image/") {
file.FileType = types.FileTypeImage
} else if strings.HasPrefix(mineType, "video/") {
file.FileType = types.FileTypeVideo
} else if strings.HasPrefix(mineType, "audio/") {
file.FileType = types.FileTypeAudio
} else {
file.FileType = types.FileTypeFile
}
file.MimeType = mineType
}
} else if strings.HasPrefix(file.OriginData, "data:") {
// get mime type from base64 header
parts := strings.SplitN(file.OriginData, ",", 2)
if len(parts) >= 1 {
header := parts[0]
// Extract mime type from "data:mime/type;base64" format
if strings.Contains(header, ":") && strings.Contains(header, ";") {
mimeStart := strings.Index(header, ":") + 1
mimeEnd := strings.Index(header, ";")
if mimeStart < mimeEnd {
mineType := header[mimeStart:mimeEnd]
if strings.HasPrefix(mineType, "image/") {
file.FileType = types.FileTypeImage
} else if strings.HasPrefix(mineType, "video/") {
file.FileType = types.FileTypeVideo
} else if strings.HasPrefix(mineType, "audio/") {
file.FileType = types.FileTypeAudio
} else {
file.FileType = types.FileTypeFile
}
file.MimeType = mineType
}
}
}
@@ -365,7 +376,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
switch file.FileType {
case types.FileTypeImage:
if info.RelayFormat == types.RelayFormatGemini {
tkm += 256
tkm += 520 // gemini per input image tokens
} else {
token, err := getImageToken(file, model, info.IsStream)
if err != nil {

View File

@@ -1,7 +1,10 @@
package service
import (
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/gin-gonic/gin"
)
//func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
@@ -16,7 +19,8 @@ import (
// return 0, errors.New("unknown relay mode")
//}
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
func ResponseText2Usage(c *gin.Context, responseText string, modeName string, promptTokens int) *dto.Usage {
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
usage := &dto.Usage{}
usage.PromptTokens = promptTokens
ctkm := CountTextToken(responseText, modeName)

View File

@@ -17,8 +17,7 @@ type GeminiSettings struct {
// 默认配置
var defaultGeminiSettings = GeminiSettings{
SafetySettings: map[string]string{
"default": "OFF",
"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
"default": "OFF",
},
VersionSettings: map[string]string{
"default": "v1beta",

View File

@@ -598,6 +598,11 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
return 2.5 / 0.3, false
} else if strings.HasPrefix(name, "gemini-robotics-er-1.5") {
return 2.5 / 0.3, false
} else if strings.HasPrefix(name, "gemini-3-pro") {
if strings.HasPrefix(name, "gemini-3-pro-image") {
return 60, false
}
return 6, false
}
return 4, false
}

View File

@@ -0,0 +1,21 @@
package system_setting
import "github.com/QuantumNous/new-api/setting/config"
type DiscordSettings struct {
Enabled bool `json:"enabled"`
ClientId string `json:"client_id"`
ClientSecret string `json:"client_secret"`
}
// 默认配置
var defaultDiscordSettings = DiscordSettings{}
func init() {
// 注册到全局配置管理器
config.GlobalConfig.Register("discord", &defaultDiscordSettings)
}
func GetDiscordSettings() *DiscordSettings {
return &defaultDiscordSettings
}

View File

@@ -192,6 +192,14 @@ function App() {
</Suspense>
}
/>
<Route
path='/oauth/discord'
element={
<Suspense fallback={<Loading></Loading>} key={location.pathname}>
<OAuth2Callback type='discord'></OAuth2Callback>
</Suspense>
}
/>
<Route
path='/oauth/oidc'
element={

View File

@@ -30,6 +30,7 @@ import {
getSystemName,
setUserData,
onGitHubOAuthClicked,
onDiscordOAuthClicked,
onOIDCClicked,
onLinuxDOOAuthClicked,
prepareCredentialRequestOptions,
@@ -53,6 +54,7 @@ import WeChatIcon from '../common/logo/WeChatIcon';
import LinuxDoIcon from '../common/logo/LinuxDoIcon';
import TwoFAVerification from './TwoFAVerification';
import { useTranslation } from 'react-i18next';
import { SiDiscord }from 'react-icons/si';
const LoginForm = () => {
let navigate = useNavigate();
@@ -73,6 +75,7 @@ const LoginForm = () => {
const [showEmailLogin, setShowEmailLogin] = useState(false);
const [wechatLoading, setWechatLoading] = useState(false);
const [githubLoading, setGithubLoading] = useState(false);
const [discordLoading, setDiscordLoading] = useState(false);
const [oidcLoading, setOidcLoading] = useState(false);
const [linuxdoLoading, setLinuxdoLoading] = useState(false);
const [emailLoginLoading, setEmailLoginLoading] = useState(false);
@@ -298,6 +301,21 @@ const LoginForm = () => {
}
};
// 包装的Discord登录点击处理
const handleDiscordClick = () => {
if ((hasUserAgreement || hasPrivacyPolicy) && !agreedToTerms) {
showInfo(t('请先阅读并同意用户协议和隐私政策'));
return;
}
setDiscordLoading(true);
try {
onDiscordOAuthClicked(status.discord_client_id);
} finally {
// 由于重定向,这里不会执行到,但为了完整性添加
setTimeout(() => setDiscordLoading(false), 3000);
}
};
// 包装的OIDC登录点击处理
const handleOIDCClick = () => {
if ((hasUserAgreement || hasPrivacyPolicy) && !agreedToTerms) {
@@ -472,6 +490,19 @@ const LoginForm = () => {
</Button>
)}
{status.discord_oauth && (
<Button
theme='outline'
className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
type='tertiary'
icon={<SiDiscord style={{ color: '#5865F2', width: '20px', height: '20px' }} />}
onClick={handleDiscordClick}
loading={discordLoading}
>
<span className='ml-3'>{t('使用 Discord 继续')}</span>
</Button>
)}
{status.oidc_enabled && (
<Button
theme='outline'
@@ -714,6 +745,7 @@ const LoginForm = () => {
</Form>
{(status.github_oauth ||
status.discord_oauth ||
status.oidc_enabled ||
status.wechat_login ||
status.linuxdo_oauth ||
@@ -849,6 +881,7 @@ const LoginForm = () => {
{showEmailLogin ||
!(
status.github_oauth ||
status.discord_oauth ||
status.oidc_enabled ||
status.wechat_login ||
status.linuxdo_oauth ||

View File

@@ -28,6 +28,7 @@ import {
updateAPI,
getSystemName,
setUserData,
onDiscordOAuthClicked,
} from '../../helpers';
import Turnstile from 'react-turnstile';
import { Button, Card, Checkbox, Divider, Form, Icon, Modal } from '@douyinfe/semi-ui';
@@ -51,6 +52,7 @@ import WeChatIcon from '../common/logo/WeChatIcon';
import TelegramLoginButton from 'react-telegram-login/src';
import { UserContext } from '../../context/User';
import { useTranslation } from 'react-i18next';
import { SiDiscord } from 'react-icons/si';
const RegisterForm = () => {
let navigate = useNavigate();
@@ -72,6 +74,7 @@ const RegisterForm = () => {
const [showEmailRegister, setShowEmailRegister] = useState(false);
const [wechatLoading, setWechatLoading] = useState(false);
const [githubLoading, setGithubLoading] = useState(false);
const [discordLoading, setDiscordLoading] = useState(false);
const [oidcLoading, setOidcLoading] = useState(false);
const [linuxdoLoading, setLinuxdoLoading] = useState(false);
const [emailRegisterLoading, setEmailRegisterLoading] = useState(false);
@@ -264,6 +267,15 @@ const RegisterForm = () => {
}
};
const handleDiscordClick = () => {
setDiscordLoading(true);
try {
onDiscordOAuthClicked(status.discord_client_id);
} finally {
setTimeout(() => setDiscordLoading(false), 3000);
}
};
const handleOIDCClick = () => {
setOidcLoading(true);
try {
@@ -377,6 +389,19 @@ const RegisterForm = () => {
</Button>
)}
{status.discord_oauth && (
<Button
theme='outline'
className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
type='tertiary'
icon={<SiDiscord style={{ color: '#5865F2', width: '20px', height: '20px' }} />}
onClick={handleDiscordClick}
loading={discordLoading}
>
<span className='ml-3'>{t('使用 Discord 继续')}</span>
</Button>
)}
{status.oidc_enabled && (
<Button
theme='outline'
@@ -591,6 +616,7 @@ const RegisterForm = () => {
</Form>
{(status.github_oauth ||
status.discord_oauth ||
status.oidc_enabled ||
status.wechat_login ||
status.linuxdo_oauth ||
@@ -686,6 +712,7 @@ const RegisterForm = () => {
{showEmailRegister ||
!(
status.github_oauth ||
status.discord_oauth ||
status.oidc_enabled ||
status.wechat_login ||
status.linuxdo_oauth ||

View File

@@ -20,7 +20,7 @@ For commercial licensing, please contact support@quantumnous.com
import React from 'react';
import { Button, Dropdown } from '@douyinfe/semi-ui';
import { Languages } from 'lucide-react';
import { CN, GB, FR, RU, JP } from 'country-flag-icons/react/3x2';
import { CN, GB, FR, RU, JP, VN } from 'country-flag-icons/react/3x2';
const LanguageSelector = ({ currentLang, onLanguageChange, t }) => {
return (
@@ -65,6 +65,13 @@ const LanguageSelector = ({ currentLang, onLanguageChange, t }) => {
<RU title='Русский' className='!w-5 !h-auto' />
<span>Русский</span>
</Dropdown.Item>
<Dropdown.Item
onClick={() => onLanguageChange('vi')}
className={`!flex !items-center !gap-2 !px-3 !py-1.5 !text-sm !text-semi-color-text-0 dark:!text-gray-200 ${currentLang === 'vi' ? '!bg-semi-color-primary-light-default dark:!bg-blue-600 !font-semibold' : 'hover:!bg-semi-color-fill-1 dark:hover:!bg-gray-600'}`}
>
<VN title='Tiếng Việt' className='!w-5 !h-auto' />
<span>Tiếng Việt</span>
</Dropdown.Item>
</Dropdown.Menu>
}
>

View File

@@ -52,6 +52,9 @@ const SystemSetting = () => {
GitHubOAuthEnabled: '',
GitHubClientId: '',
GitHubClientSecret: '',
'discord.enabled': '',
'discord.client_id': '',
'discord.client_secret': '',
'oidc.enabled': '',
'oidc.client_id': '',
'oidc.client_secret': '',
@@ -179,6 +182,7 @@ const SystemSetting = () => {
case 'EmailAliasRestrictionEnabled':
case 'SMTPSSLEnabled':
case 'LinuxDOOAuthEnabled':
case 'discord.enabled':
case 'oidc.enabled':
case 'passkey.enabled':
case 'passkey.allow_insecure_origin':
@@ -473,6 +477,27 @@ const SystemSetting = () => {
}
};
const submitDiscordOAuth = async () => {
const options = [];
if (originInputs['discord.client_id'] !== inputs['discord.client_id']) {
options.push({ key: 'discord.client_id', value: inputs['discord.client_id'] });
}
if (
originInputs['discord.client_secret'] !== inputs['discord.client_secret'] &&
inputs['discord.client_secret'] !== ''
) {
options.push({
key: 'discord.client_secret',
value: inputs['discord.client_secret'],
});
}
if (options.length > 0) {
await updateOptions(options);
}
};
const submitOIDCSettings = async () => {
if (inputs['oidc.well_known'] && inputs['oidc.well_known'] !== '') {
if (
@@ -1014,6 +1039,15 @@ const SystemSetting = () => {
>
{t('允许通过 GitHub 账户登录 & 注册')}
</Form.Checkbox>
<Form.Checkbox
field='discord.enabled'
noLabel
onChange={(e) =>
handleCheckboxChange('discord.enabled', e)
}
>
{t('允许通过 Discord 账户登录 & 注册')}
</Form.Checkbox>
<Form.Checkbox
field='LinuxDOOAuthEnabled'
noLabel
@@ -1410,6 +1444,37 @@ const SystemSetting = () => {
</Button>
</Form.Section>
</Card>
<Card>
<Form.Section text={t('配置 Discord OAuth')}>
<Text>{t('用以支持通过 Discord 进行登录注册')}</Text>
<Banner
type='info'
description={`${t('Homepage URL 填')} ${inputs.ServerAddress ? inputs.ServerAddress : t('网站地址')}${t('Authorization callback URL 填')} ${inputs.ServerAddress ? inputs.ServerAddress : t('网站地址')}/oauth/discord`}
style={{ marginBottom: 20, marginTop: 16 }}
/>
<Row
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
>
<Col xs={24} sm={24} md={12} lg={12} xl={12}>
<Form.Input
field="['discord.client_id']"
label={t('Discord Client ID')}
/>
</Col>
<Col xs={24} sm={24} md={12} lg={12} xl={12}>
<Form.Input
field="['discord.client_secret']"
label={t('Discord Client Secret')}
type='password'
placeholder={t('敏感信息不会发送到前端显示')}
/>
</Col>
</Row>
<Button onClick={submitDiscordOAuth}>
{t('保存 Discord OAuth 设置')}
</Button>
</Form.Section>
</Card>
<Card>
<Form.Section text={t('配置 Linux DO OAuth')}>
<Text>

View File

@@ -38,13 +38,14 @@ import {
IconLock,
IconDelete,
} from '@douyinfe/semi-icons';
import { SiTelegram, SiWechat, SiLinux } from 'react-icons/si';
import { SiTelegram, SiWechat, SiLinux, SiDiscord } from 'react-icons/si';
import { UserPlus, ShieldCheck } from 'lucide-react';
import TelegramLoginButton from 'react-telegram-login';
import {
onGitHubOAuthClicked,
onOIDCClicked,
onLinuxDOOAuthClicked,
onDiscordOAuthClicked,
} from '../../../../helpers';
import TwoFASetting from '../components/TwoFASetting';
@@ -247,6 +248,47 @@ const AccountManagement = ({
</div>
</Card>
{/* Discord绑定 */}
<Card className='!rounded-xl'>
<div className='flex items-center justify-between gap-3'>
<div className='flex items-center flex-1 min-w-0'>
<div className='w-10 h-10 rounded-full bg-slate-100 dark:bg-slate-700 flex items-center justify-center mr-3 flex-shrink-0'>
<SiDiscord
size={20}
className='text-slate-600 dark:text-slate-300'
/>
</div>
<div className='flex-1 min-w-0'>
<div className='font-medium text-gray-900'>
{t('Discord')}
</div>
<div className='text-sm text-gray-500 truncate'>
{renderAccountInfo(
userState.user?.discord_id,
t('Discord ID'),
)}
</div>
</div>
</div>
<div className='flex-shrink-0'>
<Button
type='primary'
theme='outline'
size='small'
onClick={() =>
onDiscordOAuthClicked(status.discord_client_id)
}
disabled={
isBound(userState.user?.discord_id) ||
!status.discord_oauth
}
>
{status.discord_oauth ? t('绑定') : t('未启用')}
</Button>
</div>
</div>
</Card>
{/* OIDC绑定 */}
<Card className='!rounded-xl'>
<div className='flex items-center justify-between gap-3'>

View File

@@ -190,6 +190,30 @@ const EditChannelModal = (props) => {
const [keyMode, setKeyMode] = useState('append'); // 密钥模式replace覆盖或 append追加
const [isEnterpriseAccount, setIsEnterpriseAccount] = useState(false); // 是否为企业账户
const [doubaoApiEditUnlocked, setDoubaoApiEditUnlocked] = useState(false); // 豆包渠道自定义 API 地址隐藏入口
const redirectModelList = useMemo(() => {
const mapping = inputs.model_mapping;
if (typeof mapping !== 'string') return [];
const trimmed = mapping.trim();
if (!trimmed) return [];
try {
const parsed = JSON.parse(trimmed);
if (
!parsed ||
typeof parsed !== 'object' ||
Array.isArray(parsed)
) {
return [];
}
const values = Object.values(parsed)
.map((value) =>
typeof value === 'string' ? value.trim() : undefined,
)
.filter((value) => value);
return Array.from(new Set(values));
} catch (error) {
return [];
}
}, [inputs.model_mapping]);
// 密钥显示状态
const [keyDisplayState, setKeyDisplayState] = useState({
@@ -220,6 +244,8 @@ const EditChannelModal = (props) => {
];
const formContainerRef = useRef(null);
const doubaoApiClickCountRef = useRef(0);
const initialModelsRef = useRef([]);
const initialModelMappingRef = useRef('');
// 2FA状态更新辅助函数
const updateTwoFAState = (updates) => {
@@ -595,6 +621,10 @@ const EditChannelModal = (props) => {
system_prompt: data.system_prompt,
system_prompt_override: data.system_prompt_override || false,
});
initialModelsRef.current = (data.models || [])
.map((model) => (model || '').trim())
.filter(Boolean);
initialModelMappingRef.current = data.model_mapping || '';
// console.log(data);
} else {
showError(message);
@@ -830,6 +860,13 @@ const EditChannelModal = (props) => {
}
}, [props.visible, channelId]);
useEffect(() => {
if (!isEdit) {
initialModelsRef.current = [];
initialModelMappingRef.current = '';
}
}, [isEdit, props.visible]);
// 统一的模态框重置函数
const resetModalState = () => {
formApiRef.current?.reset();
@@ -903,6 +940,80 @@ const EditChannelModal = (props) => {
})();
};
const confirmMissingModelMappings = (missingModels) =>
new Promise((resolve) => {
const modal = Modal.confirm({
title: t('模型未加入列表,可能无法调用'),
content: (
<div className='text-sm leading-6'>
<div>
{t(
'模型重定向里的下列模型尚未添加到“模型”列表,调用时会因为缺少可用模型而失败:',
)}
</div>
<div className='font-mono text-xs break-all text-red-600 mt-1'>
{missingModels.join(', ')}
</div>
<div className='mt-2'>
{t(
'你可以在“自定义模型名称”处手动添加它们,然后点击填入后再提交,或者直接使用下方操作自动处理。',
)}
</div>
</div>
),
centered: true,
footer: (
<Space align='center' className='w-full justify-end'>
<Button
type='tertiary'
onClick={() => {
modal.destroy();
resolve('cancel');
}}
>
{t('返回修改')}
</Button>
<Button
type='primary'
theme='light'
onClick={() => {
modal.destroy();
resolve('submit');
}}
>
{t('直接提交')}
</Button>
<Button
type='primary'
theme='solid'
onClick={() => {
modal.destroy();
resolve('add');
}}
>
{t('添加后提交')}
</Button>
</Space>
),
});
});
const hasModelConfigChanged = (normalizedModels, modelMappingStr) => {
if (!isEdit) return true;
const initialModels = initialModelsRef.current;
if (normalizedModels.length !== initialModels.length) {
return true;
}
for (let i = 0; i < normalizedModels.length; i++) {
if (normalizedModels[i] !== initialModels[i]) {
return true;
}
}
const normalizedMapping = (modelMappingStr || '').trim();
const initialMapping = (initialModelMappingRef.current || '').trim();
return normalizedMapping !== initialMapping;
};
const submit = async () => {
const formValues = formApiRef.current ? formApiRef.current.getValues() : {};
let localInputs = { ...formValues };
@@ -986,14 +1097,55 @@ const EditChannelModal = (props) => {
showInfo(t('请输入API地址'));
return;
}
if (
localInputs.model_mapping &&
localInputs.model_mapping !== '' &&
!verifyJSON(localInputs.model_mapping)
) {
showInfo(t('模型映射必须是合法的 JSON 格式!'));
return;
const hasModelMapping =
typeof localInputs.model_mapping === 'string' &&
localInputs.model_mapping.trim() !== '';
let parsedModelMapping = null;
if (hasModelMapping) {
if (!verifyJSON(localInputs.model_mapping)) {
showInfo(t('模型映射必须是合法的 JSON 格式!'));
return;
}
try {
parsedModelMapping = JSON.parse(localInputs.model_mapping);
} catch (error) {
showInfo(t('模型映射必须是合法的 JSON 格式!'));
return;
}
}
const normalizedModels = (localInputs.models || [])
.map((model) => (model || '').trim())
.filter(Boolean);
localInputs.models = normalizedModels;
if (
parsedModelMapping &&
typeof parsedModelMapping === 'object' &&
!Array.isArray(parsedModelMapping)
) {
const modelSet = new Set(normalizedModels);
const missingModels = Object.keys(parsedModelMapping)
.map((key) => (key || '').trim())
.filter((key) => key && !modelSet.has(key));
const shouldPromptMissing =
missingModels.length > 0 &&
hasModelConfigChanged(normalizedModels, localInputs.model_mapping);
if (shouldPromptMissing) {
const confirmAction = await confirmMissingModelMappings(missingModels);
if (confirmAction === 'cancel') {
return;
}
if (confirmAction === 'add') {
const updatedModels = Array.from(
new Set([...normalizedModels, ...missingModels]),
);
localInputs.models = updatedModels;
handleInputChange('models', updatedModels);
}
}
}
if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
localInputs.base_url = localInputs.base_url.slice(
0,
@@ -2916,6 +3068,7 @@ const EditChannelModal = (props) => {
visible={modelModalVisible}
models={fetchedModels}
selected={inputs.models}
redirectModels={redirectModelList}
onConfirm={(selectedModels) => {
handleInputChange('models', selectedModels);
showSuccess(t('模型列表已更新'));

View File

@@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
For commercial licensing, please contact support@quantumnous.com
*/
import React, { useState, useEffect } from 'react';
import React, { useState, useEffect, useMemo } from 'react';
import { useIsMobile } from '../../../../hooks/common/useIsMobile';
import {
Modal,
@@ -28,12 +28,13 @@ import {
Empty,
Tabs,
Collapse,
Tooltip,
} from '@douyinfe/semi-ui';
import {
IllustrationNoResult,
IllustrationNoResultDark,
} from '@douyinfe/semi-illustrations';
import { IconSearch } from '@douyinfe/semi-icons';
import { IconSearch, IconInfoCircle } from '@douyinfe/semi-icons';
import { useTranslation } from 'react-i18next';
import { getModelCategories } from '../../../../helpers/render';
@@ -41,6 +42,7 @@ const ModelSelectModal = ({
visible,
models = [],
selected = [],
redirectModels = [],
onConfirm,
onCancel,
}) => {
@@ -50,15 +52,54 @@ const ModelSelectModal = ({
const [activeTab, setActiveTab] = useState('new');
const isMobile = useIsMobile();
const normalizeModelName = (model) =>
typeof model === 'string' ? model.trim() : '';
const normalizedRedirectModels = useMemo(
() =>
Array.from(
new Set(
(redirectModels || [])
.map((model) => normalizeModelName(model))
.filter(Boolean),
),
),
[redirectModels],
);
const normalizedSelectedSet = useMemo(() => {
const set = new Set();
(selected || []).forEach((model) => {
const normalized = normalizeModelName(model);
if (normalized) {
set.add(normalized);
}
});
return set;
}, [selected]);
const classificationSet = useMemo(() => {
const set = new Set(normalizedSelectedSet);
normalizedRedirectModels.forEach((model) => set.add(model));
return set;
}, [normalizedSelectedSet, normalizedRedirectModels]);
const redirectOnlySet = useMemo(() => {
const set = new Set();
normalizedRedirectModels.forEach((model) => {
if (!normalizedSelectedSet.has(model)) {
set.add(model);
}
});
return set;
}, [normalizedRedirectModels, normalizedSelectedSet]);
const filteredModels = models.filter((m) =>
m.toLowerCase().includes(keyword.toLowerCase()),
String(m || '').toLowerCase().includes(keyword.toLowerCase()),
);
// 分类模型:新获取的模型和已有模型
const newModels = filteredModels.filter((model) => !selected.includes(model));
const isExistingModel = (model) =>
classificationSet.has(normalizeModelName(model));
const newModels = filteredModels.filter((model) => !isExistingModel(model));
const existingModels = filteredModels.filter((model) =>
selected.includes(model),
isExistingModel(model),
);
// 同步外部选中值
@@ -228,7 +269,20 @@ const ModelSelectModal = ({
<div className='grid grid-cols-2 gap-x-4'>
{categoryData.models.map((model) => (
<Checkbox key={model} value={model} className='my-1'>
{model}
<span className='flex items-center gap-2'>
<span>{model}</span>
{redirectOnlySet.has(normalizeModelName(model)) && (
<Tooltip
position='top'
content={t('来自模型重定向,尚未加入模型列表')}
>
<IconInfoCircle
size='small'
className='text-amber-500 cursor-help'
/>
</Tooltip>
)}
</span>
</Checkbox>
))}
</div>

View File

@@ -72,6 +72,7 @@ const EditUserModal = (props) => {
password: '',
github_id: '',
oidc_id: '',
discord_id: '',
wechat_id: '',
telegram_id: '',
email: '',
@@ -332,6 +333,7 @@ const EditUserModal = (props) => {
<Row gutter={12}>
{[
'github_id',
'discord_id',
'oidc_id',
'wechat_id',
'email',

View File

@@ -231,6 +231,17 @@ export async function getOAuthState() {
}
}
export async function onDiscordOAuthClicked(client_id) {
const state = await getOAuthState();
if (!state) return;
const redirect_uri = `${window.location.origin}/oauth/discord`;
const response_type = 'code';
const scope = 'identify+openid';
window.open(
`https://discord.com/oauth2/authorize?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`,
);
}
export async function onOIDCClicked(auth_url, client_id, openInNewTab = false) {
const state = await getOAuthState();
if (!state) return;

View File

@@ -482,6 +482,18 @@ export const useLogsData = () => {
value: other.request_path,
});
}
if (isAdminUser) {
let localCountMode = '';
if (other?.admin_info?.local_count_tokens) {
localCountMode = t('本地计费');
} else {
localCountMode = t('上游返回');
}
expandDataLocal.push({
key: t('计费模式'),
value: localCountMode,
});
}
expandDatesLocal[logs[i].key] = expandDataLocal;
}

View File

@@ -26,6 +26,7 @@ import frTranslation from './locales/fr.json';
import zhTranslation from './locales/zh.json';
import ruTranslation from './locales/ru.json';
import jaTranslation from './locales/ja.json';
import viTranslation from './locales/vi.json';
i18n
.use(LanguageDetector)
@@ -38,6 +39,7 @@ i18n
fr: frTranslation,
ru: ruTranslation,
ja: jaTranslation,
vi: viTranslation,
},
fallbackLng: 'zh',
interpolation: {

2700
web/src/i18n/locales/vi.json Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -257,6 +257,7 @@
"余额充值管理": "余额充值管理",
"你似乎并没有修改什么": "你似乎并没有修改什么",
"使用 GitHub 继续": "使用 GitHub 继续",
"使用 Discord 继续": "使用 Discord 继续",
"使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}": "使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}",
"使用 LinuxDO 继续": "使用 LinuxDO 继续",
"使用 OIDC 继续": "使用 OIDC 继续",