mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-04 22:57:18 +00:00
Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3f19f18dc9 | ||
|
|
a465597e78 | ||
|
|
dbfcb441f7 | ||
|
|
3fb2ba318d | ||
|
|
8f039b3a53 | ||
|
|
c939686509 | ||
|
|
07aff1fe02 | ||
|
|
5f27edcd19 | ||
|
|
f47d473e63 | ||
|
|
7a2bd38700 | ||
|
|
f8c40ecca6 | ||
|
|
2bc991685f | ||
|
|
87811a0493 | ||
|
|
0885597427 | ||
|
|
0952973887 | ||
|
|
6b30f042fa | ||
|
|
efb8f1f5b8 | ||
|
|
de3cf9893d | ||
|
|
fe02e9a066 | ||
|
|
84745d5ca4 | ||
|
|
cdb1c06ad2 | ||
|
|
d9b5748f80 |
@@ -63,7 +63,7 @@
|
||||
# 是否统计图片token
|
||||
# GET_MEDIA_TOKEN=true
|
||||
# 是否在非流(stream=false)情况下统计图片token
|
||||
# GET_MEDIA_TOKEN_NOT_STREAM=true
|
||||
# GET_MEDIA_TOKEN_NOT_STREAM=false
|
||||
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
|
||||
# DIFY_DEBUG=true
|
||||
|
||||
|
||||
24
.github/workflows/release.yml
vendored
24
.github/workflows/release.yml
vendored
@@ -22,6 +22,10 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
@@ -31,7 +35,7 @@ jobs:
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
@@ -40,13 +44,11 @@ jobs:
|
||||
- name: Build Backend (amd64)
|
||||
run: |
|
||||
go mod download
|
||||
VERSION=$(git describe --tags)
|
||||
go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION' -extldflags '-static'" -o new-api-$VERSION
|
||||
- name: Build Backend (arm64)
|
||||
run: |
|
||||
sudo apt-get update
|
||||
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
|
||||
VERSION=$(git describe --tags)
|
||||
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION' -extldflags '-static'" -o new-api-arm64-$VERSION
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
@@ -65,6 +67,10 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
@@ -75,7 +81,7 @@ jobs:
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
@@ -84,7 +90,6 @@ jobs:
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
VERSION=$(git describe --tags)
|
||||
go build -ldflags "-X 'new-api/common.Version=$VERSION'" -o new-api-macos-$VERSION
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
@@ -105,6 +110,10 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
@@ -114,7 +123,7 @@ jobs:
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
@@ -123,7 +132,6 @@ jobs:
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
VERSION=$(git describe --tags)
|
||||
go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION'" -o new-api-$VERSION.exe
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
@@ -132,5 +140,3 @@ jobs:
|
||||
files: new-api-*.exe
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
|
||||
|
||||
@@ -193,6 +193,7 @@ docker run --name new-api -d --restart always \
|
||||
|
||||
### 🔐 Authorization and Security
|
||||
|
||||
- 😈 Discord authorization login
|
||||
- 🤖 LinuxDO authorization login
|
||||
- 📱 Telegram authorization login
|
||||
- 🔑 OIDC unified authentication
|
||||
|
||||
@@ -193,6 +193,7 @@ docker run --name new-api -d --restart always \
|
||||
|
||||
### 🔐 授权与安全
|
||||
|
||||
- 😈 Discord 授权登录
|
||||
- 🤖 LinuxDO 授权登录
|
||||
- 📱 Telegram 授权登录
|
||||
- 🔑 OIDC 统一认证
|
||||
|
||||
@@ -30,6 +30,11 @@ func printHelp() {
|
||||
func InitEnv() {
|
||||
flag.Parse()
|
||||
|
||||
envVersion := os.Getenv("VERSION")
|
||||
if envVersion != "" {
|
||||
Version = envVersion
|
||||
}
|
||||
|
||||
if *PrintVersion {
|
||||
fmt.Println(Version)
|
||||
os.Exit(0)
|
||||
@@ -111,8 +116,9 @@ func initConstantEnv() {
|
||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
|
||||
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", false)
|
||||
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
|
||||
@@ -46,5 +46,7 @@ const (
|
||||
ContextKeyUsingGroup ContextKey = "group"
|
||||
ContextKeyUserName ContextKey = "username"
|
||||
|
||||
ContextKeyLocalCountTokens ContextKey = "local_count_tokens"
|
||||
|
||||
ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ var StreamingTimeout int
|
||||
var DifyDebug bool
|
||||
var MaxFileDownloadMB int
|
||||
var ForceStreamOption bool
|
||||
var CountToken bool
|
||||
var GetMediaToken bool
|
||||
var GetMediaTokenNotStream bool
|
||||
var UpdateTask bool
|
||||
|
||||
223
controller/discord.go
Normal file
223
controller/discord.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DiscordResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type DiscordUser struct {
|
||||
UID string `json:"id"`
|
||||
ID string `json:"username"`
|
||||
Name string `json:"global_name"`
|
||||
}
|
||||
|
||||
func getDiscordUserInfoByCode(code string) (*DiscordUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("client_id", system_setting.GetDiscordSettings().ClientId)
|
||||
values.Set("client_secret", system_setting.GetDiscordSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var discordResponse DiscordResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&discordResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if discordResponse.AccessToken == "" {
|
||||
common.SysError("Discord 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("Discord 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysError("Discord 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("Discord 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
var discordUser DiscordUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&discordUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if discordUser.UID == "" || discordUser.ID == "" {
|
||||
common.SysError("Discord 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("Discord 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &discordUser, nil
|
||||
}
|
||||
|
||||
func DiscordOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
DiscordBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
discordUser, err := getDiscordUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
DiscordId: discordUser.UID,
|
||||
}
|
||||
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
|
||||
err := user.FillUserByDiscordId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
if discordUser.ID != "" {
|
||||
user.Username = discordUser.ID
|
||||
} else {
|
||||
user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if discordUser.Name != "" {
|
||||
user.DisplayName = discordUser.Name
|
||||
} else {
|
||||
user.DisplayName = "Discord User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func DiscordBind(c *gin.Context) {
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
discordUser, err := getDiscordUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
DiscordId: discordUser.UID,
|
||||
}
|
||||
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 Discord 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.DiscordId = discordUser.UID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
}
|
||||
@@ -52,6 +52,8 @@ func GetStatus(c *gin.Context) {
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"discord_oauth": system_setting.GetDiscordSettings().Enabled,
|
||||
"discord_client_id": system_setting.GetDiscordSettings().ClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,
|
||||
|
||||
@@ -71,6 +71,14 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "discord.enabled":
|
||||
if option.Value == "true" && system_setting.GetDiscordSettings().ClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 Discord OAuth,请先填入 Discord Client Id 以及 Discord Client Secret!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "oidc.enabled":
|
||||
if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -453,6 +453,7 @@ func GetSelf(c *gin.Context) {
|
||||
"status": user.Status,
|
||||
"email": user.Email,
|
||||
"github_id": user.GitHubId,
|
||||
"discord_id": user.DiscordId,
|
||||
"oidc_id": user.OidcId,
|
||||
"wechat_id": user.WeChatId,
|
||||
"telegram_id": user.TelegramId,
|
||||
|
||||
@@ -42,6 +42,7 @@
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 |
|
||||
| GET | /api/oauth/discord | 公开 | Discord 通用 OAuth 跳转 |
|
||||
| GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 |
|
||||
| GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 |
|
||||
| GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 |
|
||||
|
||||
@@ -142,7 +142,38 @@ type GeminiThinkingConfig struct {
|
||||
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
// TODO Conflict with thinkingbudget.
|
||||
// ThinkingLevel json.RawMessage `json:"thinkingLevel,omitempty"`
|
||||
ThinkingLevel json.RawMessage `json:"thinkingLevel,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON allows GeminiThinkingConfig to accept both snake_case and camelCase fields.
|
||||
func (c *GeminiThinkingConfig) UnmarshalJSON(data []byte) error {
|
||||
type Alias GeminiThinkingConfig
|
||||
var aux struct {
|
||||
Alias
|
||||
IncludeThoughtsSnake *bool `json:"include_thoughts,omitempty"`
|
||||
ThinkingBudgetSnake *int `json:"thinking_budget,omitempty"`
|
||||
ThinkingLevelSnake json.RawMessage `json:"thinking_level,omitempty"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GeminiThinkingConfig(aux.Alias)
|
||||
|
||||
if aux.IncludeThoughtsSnake != nil {
|
||||
c.IncludeThoughts = *aux.IncludeThoughtsSnake
|
||||
}
|
||||
|
||||
if aux.ThinkingBudgetSnake != nil {
|
||||
c.ThinkingBudget = aux.ThinkingBudgetSnake
|
||||
}
|
||||
|
||||
if len(aux.ThinkingLevelSnake) > 0 {
|
||||
c.ThinkingLevel = aux.ThinkingLevelSnake
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
|
||||
|
||||
@@ -897,6 +897,12 @@ type Reasoning struct {
|
||||
Summary string `json:"summary,omitempty"`
|
||||
}
|
||||
|
||||
type Input struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type MediaInput struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
@@ -915,7 +921,7 @@ func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
|
||||
return nil
|
||||
}
|
||||
|
||||
var inputs []MediaInput
|
||||
var mediaInputs []MediaInput
|
||||
|
||||
// Try string first
|
||||
// if str, ok := common.GetJsonType(r.Input); ok {
|
||||
@@ -925,60 +931,74 @@ func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
|
||||
if common.GetJsonType(r.Input) == "string" {
|
||||
var str string
|
||||
_ = common.Unmarshal(r.Input, &str)
|
||||
inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
|
||||
return inputs
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_text", Text: str})
|
||||
return mediaInputs
|
||||
}
|
||||
|
||||
// Try array of parts
|
||||
if common.GetJsonType(r.Input) == "array" {
|
||||
var array []any
|
||||
_ = common.Unmarshal(r.Input, &array)
|
||||
for _, itemAny := range array {
|
||||
// Already parsed MediaInput
|
||||
if media, ok := itemAny.(MediaInput); ok {
|
||||
inputs = append(inputs, media)
|
||||
continue
|
||||
var inputs []Input
|
||||
_ = common.Unmarshal(r.Input, &inputs)
|
||||
for _, input := range inputs {
|
||||
if common.GetJsonType(input.Content) == "string" {
|
||||
var str string
|
||||
_ = common.Unmarshal(input.Content, &str)
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_text", Text: str})
|
||||
}
|
||||
// Generic map
|
||||
item, ok := itemAny.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typeVal, ok := item["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch typeVal {
|
||||
case "input_text":
|
||||
text, _ := item["text"].(string)
|
||||
inputs = append(inputs, MediaInput{Type: "input_text", Text: text})
|
||||
case "input_image":
|
||||
// image_url may be string or object with url field
|
||||
var imageUrl string
|
||||
switch v := item["image_url"].(type) {
|
||||
case string:
|
||||
imageUrl = v
|
||||
case map[string]any:
|
||||
if url, ok := v["url"].(string); ok {
|
||||
imageUrl = url
|
||||
|
||||
if common.GetJsonType(input.Content) == "array" {
|
||||
var array []any
|
||||
_ = common.Unmarshal(input.Content, &array)
|
||||
for _, itemAny := range array {
|
||||
// Already parsed MediaContent
|
||||
if media, ok := itemAny.(MediaInput); ok {
|
||||
mediaInputs = append(mediaInputs, media)
|
||||
continue
|
||||
}
|
||||
|
||||
// Generic map
|
||||
item, ok := itemAny.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
typeVal, ok := item["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch typeVal {
|
||||
case "input_text":
|
||||
text, _ := item["text"].(string)
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_text", Text: text})
|
||||
case "input_image":
|
||||
// image_url may be string or object with url field
|
||||
var imageUrl string
|
||||
switch v := item["image_url"].(type) {
|
||||
case string:
|
||||
imageUrl = v
|
||||
case map[string]any:
|
||||
if url, ok := v["url"].(string); ok {
|
||||
imageUrl = url
|
||||
}
|
||||
}
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
|
||||
case "input_file":
|
||||
// file_url may be string or object with url field
|
||||
var fileUrl string
|
||||
switch v := item["file_url"].(type) {
|
||||
case string:
|
||||
fileUrl = v
|
||||
case map[string]any:
|
||||
if url, ok := v["url"].(string); ok {
|
||||
fileUrl = url
|
||||
}
|
||||
}
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
|
||||
}
|
||||
}
|
||||
inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
|
||||
case "input_file":
|
||||
// file_url may be string or object with url field
|
||||
var fileUrl string
|
||||
switch v := item["file_url"].(type) {
|
||||
case string:
|
||||
fileUrl = v
|
||||
case map[string]any:
|
||||
if url, ok := v["url"].(string); ok {
|
||||
fileUrl = url
|
||||
}
|
||||
}
|
||||
inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return inputs
|
||||
return mediaInputs
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ type User struct {
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
DiscordId string `json:"discord_id" gorm:"column:discord_id;index"`
|
||||
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
|
||||
@@ -539,6 +540,14 @@ func (user *User) FillUserByGitHubId() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByDiscordId() error {
|
||||
if user.DiscordId == "" {
|
||||
return errors.New("discord id 为空!")
|
||||
}
|
||||
DB.Where(User{DiscordId: user.DiscordId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByOidcId() error {
|
||||
if user.OidcId == "" {
|
||||
return errors.New("oidc id 为空!")
|
||||
@@ -578,6 +587,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
|
||||
return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsDiscordIdAlreadyTaken(discordId string) bool {
|
||||
return DB.Unscoped().Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsOidcIdAlreadyTaken(oidcId string) bool {
|
||||
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
@@ -673,7 +673,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
@@ -682,7 +682,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
||||
if common.DebugEnabled {
|
||||
common.SysLog("claude response usage is not complete, maybe upstream error")
|
||||
}
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
|
||||
if err := scanner.Err(); err != nil {
|
||||
logger.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||
}
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
@@ -105,7 +105,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response)
|
||||
for _, choice := range response.Choices {
|
||||
responseText += choice.Message.StringContent()
|
||||
}
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
response.Usage = *usage
|
||||
response.Id = helper.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
|
||||
@@ -165,7 +165,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
}
|
||||
})
|
||||
if usage.PromptTokens == 0 {
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -142,7 +142,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
|
||||
helper.Done(c)
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
|
||||
@@ -246,7 +246,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
})
|
||||
helper.Done(c)
|
||||
if usage.TotalTokens == 0 {
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
usage.CompletionTokens += nodeToken
|
||||
return usage, nil
|
||||
|
||||
@@ -32,7 +32,7 @@ var SafetySettingList = []string{
|
||||
"HARM_CATEGORY_HATE_SPEECH",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"HARM_CATEGORY_CIVIC_INTEGRITY",
|
||||
//"HARM_CATEGORY_CIVIC_INTEGRITY", This item is deprecated!
|
||||
}
|
||||
|
||||
var ChannelName = "google gemini"
|
||||
|
||||
@@ -3,9 +3,9 @@ package gemini
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
@@ -13,8 +13,6 @@ import (
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -77,6 +75,8 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
|
||||
if info.IsGeminiBatchEmbedding {
|
||||
var geminiResponse dto.GeminiBatchEmbeddingResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
@@ -97,80 +97,15 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
|
||||
}
|
||||
|
||||
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
// 统计图片数量
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||
imageCount++
|
||||
}
|
||||
if part.Text != "" {
|
||||
responseText.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新使用量统计
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
|
||||
// 直接发送 GeminiChatResponse 响应
|
||||
err = helper.StringData(c, data)
|
||||
err := helper.StringData(c, data)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
info.SendResponseCount++
|
||||
return true
|
||||
})
|
||||
|
||||
if info.SendResponseCount == 0 {
|
||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 258
|
||||
}
|
||||
}
|
||||
|
||||
// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
|
||||
if usage.CompletionTokens == 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 空补全,不需要使用量
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
}
|
||||
|
||||
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
||||
//helper.Done(c)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -954,14 +954,10 @@ func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.Ch
|
||||
return nil
|
||||
}
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
// responseText := ""
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
responseText := strings.Builder{}
|
||||
func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response, callback func(data string, geminiResponse *dto.GeminiChatResponse) bool) (*dto.Usage, *types.NewAPIError) {
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
finishReason := constant.FinishReasonStop
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
@@ -971,6 +967,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
return false
|
||||
}
|
||||
|
||||
// 统计图片数量
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||
@@ -982,14 +979,10 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
}
|
||||
}
|
||||
|
||||
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
// 更新使用量统计
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
@@ -1000,6 +993,45 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return callback(data, &geminiResponse)
|
||||
})
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 1400
|
||||
}
|
||||
}
|
||||
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
if usage.TotalTokens > 0 {
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
}
|
||||
|
||||
if usage.CompletionTokens <= 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
finishReason := constant.FinishReasonStop
|
||||
|
||||
usage, err := geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
|
||||
response, isStop := streamResponseGeminiChat2OpenAI(geminiResponse)
|
||||
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
|
||||
logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
|
||||
if info.SendResponseCount == 0 {
|
||||
// send first response
|
||||
@@ -1015,7 +1047,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
emptyResponse.Choices[0].Delta.ToolCalls = copiedToolCalls
|
||||
}
|
||||
finishReason = constant.FinishReasonToolCalls
|
||||
err = handleStream(c, info, emptyResponse)
|
||||
err := handleStream(c, info, emptyResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
@@ -1025,14 +1057,14 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
response.Choices[0].FinishReason = nil
|
||||
}
|
||||
} else {
|
||||
err = handleStream(c, info, emptyResponse)
|
||||
err := handleStream(c, info, emptyResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = handleStream(c, info, response)
|
||||
err := handleStream(c, info, response)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
@@ -1042,40 +1074,15 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
return true
|
||||
})
|
||||
|
||||
if info.SendResponseCount == 0 {
|
||||
// 空补全,报错不计费
|
||||
// empty response, throw an error
|
||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 258
|
||||
}
|
||||
}
|
||||
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
|
||||
if usage.CompletionTokens == 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 空补全,不需要使用量
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
if err != nil {
|
||||
return usage, err
|
||||
}
|
||||
|
||||
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
err := handleFinalStream(c, info, response)
|
||||
if err != nil {
|
||||
common.SysLog("send final response failed: " + err.Error())
|
||||
handleErr := handleFinalStream(c, info, response)
|
||||
if handleErr != nil {
|
||||
common.SysLog("send final response failed: " + handleErr.Error())
|
||||
}
|
||||
//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
// helper.Done(c)
|
||||
//}
|
||||
//resp.Body.Close()
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -183,7 +183,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = palmStreamHandler(c, resp)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
usage, err = palmHandler(c, info, resp)
|
||||
}
|
||||
|
||||
@@ -130,7 +130,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
|
||||
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
|
||||
return service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens), nil
|
||||
}
|
||||
|
||||
func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
|
||||
@@ -70,7 +70,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
})
|
||||
|
||||
if !containStreamUsage {
|
||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -30,7 +30,7 @@ type ParamOperation struct {
|
||||
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
|
||||
}
|
||||
|
||||
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
||||
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
|
||||
if len(paramOverride) == 0 {
|
||||
return jsonData, nil
|
||||
}
|
||||
@@ -38,7 +38,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) (
|
||||
// 尝试断言为操作格式
|
||||
if operations, ok := tryParseOperations(paramOverride); ok {
|
||||
// 使用新方法
|
||||
result, err := applyOperations(string(jsonData), operations)
|
||||
result, err := applyOperations(string(jsonData), operations, conditionContext)
|
||||
return []byte(result), err
|
||||
}
|
||||
|
||||
@@ -123,13 +123,13 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) {
|
||||
func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
|
||||
if len(conditions) == 0 {
|
||||
return true, nil // 没有条件,直接通过
|
||||
}
|
||||
results := make([]bool, len(conditions))
|
||||
for i, condition := range conditions {
|
||||
result, err := checkSingleCondition(jsonStr, condition)
|
||||
result, err := checkSingleCondition(jsonStr, contextJSON, condition)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -153,10 +153,13 @@ func checkConditions(jsonStr string, conditions []ConditionOperation, logic stri
|
||||
}
|
||||
}
|
||||
|
||||
func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
|
||||
func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
|
||||
// 处理负数索引
|
||||
path := processNegativeIndex(jsonStr, condition.Path)
|
||||
value := gjson.Get(jsonStr, path)
|
||||
if !value.Exists() && contextJSON != "" {
|
||||
value = gjson.Get(contextJSON, condition.Path)
|
||||
}
|
||||
if !value.Exists() {
|
||||
if condition.PassMissingKey {
|
||||
return true, nil
|
||||
@@ -165,7 +168,7 @@ func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, e
|
||||
}
|
||||
|
||||
// 利用gjson的类型解析
|
||||
targetBytes, err := json.Marshal(condition.Value)
|
||||
targetBytes, err := common.Marshal(condition.Value)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to marshal condition value: %v", err)
|
||||
}
|
||||
@@ -292,7 +295,7 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool,
|
||||
// applyOperationsLegacy 原参数覆盖方法
|
||||
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
||||
reqMap := make(map[string]interface{})
|
||||
err := json.Unmarshal(jsonData, &reqMap)
|
||||
err := common.Unmarshal(jsonData, &reqMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -301,14 +304,23 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}
|
||||
reqMap[key] = value
|
||||
}
|
||||
|
||||
return json.Marshal(reqMap)
|
||||
return common.Marshal(reqMap)
|
||||
}
|
||||
|
||||
func applyOperations(jsonStr string, operations []ParamOperation) (string, error) {
|
||||
func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
|
||||
var contextJSON string
|
||||
if conditionContext != nil && len(conditionContext) > 0 {
|
||||
ctxBytes, err := common.Marshal(conditionContext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal condition context: %v", err)
|
||||
}
|
||||
contextJSON = string(ctxBytes)
|
||||
}
|
||||
|
||||
result := jsonStr
|
||||
for _, op := range operations {
|
||||
// 检查条件是否满足
|
||||
ok, err := checkConditions(result, op.Conditions, op.Logic)
|
||||
ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -414,7 +426,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
||||
var currentMap, newMap map[string]interface{}
|
||||
|
||||
// 解析当前值
|
||||
if err := json.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
|
||||
if err := common.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// 解析新值
|
||||
@@ -422,8 +434,8 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
||||
case map[string]interface{}:
|
||||
newMap = v
|
||||
default:
|
||||
jsonBytes, _ := json.Marshal(v)
|
||||
if err := json.Unmarshal(jsonBytes, &newMap); err != nil {
|
||||
jsonBytes, _ := common.Marshal(v)
|
||||
if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
@@ -439,3 +451,31 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
||||
}
|
||||
return sjson.Set(jsonStr, path, result)
|
||||
}
|
||||
|
||||
// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
|
||||
// 目前内置以下字段:
|
||||
// - model:优先使用上游模型名(UpstreamModelName),若不存在则回落到原始模型名(OriginModelName)。
|
||||
// - upstream_model:始终为通道映射后的上游模型名。
|
||||
// - original_model:请求最初指定的模型名。
|
||||
func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
|
||||
if info == nil || info.ChannelMeta == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := make(map[string]interface{})
|
||||
if info.UpstreamModelName != "" {
|
||||
ctx["model"] = info.UpstreamModelName
|
||||
ctx["upstream_model"] = info.UpstreamModelName
|
||||
}
|
||||
if info.OriginModelName != "" {
|
||||
ctx["original_model"] = info.OriginModelName
|
||||
if _, exists := ctx["model"]; !exists {
|
||||
ctx["model"] = info.OriginModelName
|
||||
}
|
||||
}
|
||||
|
||||
if len(ctx) == 0 {
|
||||
return nil
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -49,6 +49,14 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData)))
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
@@ -156,7 +156,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
|
||||
apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth)
|
||||
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth)
|
||||
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
// return image.Config, format, clean base64 string, error
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
|
||||
// 去除base64数据的URL前缀(如果有)
|
||||
if idx := strings.Index(base64String, ","); idx != -1 {
|
||||
|
||||
@@ -62,6 +62,12 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
adminInfo["is_multi_key"] = true
|
||||
adminInfo["multi_key_index"] = common.GetContextKeyInt(ctx, constant.ContextKeyChannelMultiKeyIndex)
|
||||
}
|
||||
|
||||
isLocalCountTokens := common.GetContextKeyBool(ctx, constant.ContextKeyLocalCountTokens)
|
||||
if isLocalCountTokens {
|
||||
adminInfo["local_count_tokens"] = isLocalCountTokens
|
||||
}
|
||||
|
||||
other["admin_info"] = adminInfo
|
||||
appendRequestPath(ctx, relayInfo, other)
|
||||
return other
|
||||
|
||||
@@ -143,6 +143,12 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
if fileMeta.Detail == "low" && !isPatchBased {
|
||||
return baseTokens, nil
|
||||
}
|
||||
|
||||
// Whether to count image tokens at all
|
||||
if !constant.GetMediaToken {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
|
||||
if !constant.GetMediaTokenNotStream && !stream {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
@@ -150,10 +156,6 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
|
||||
fileMeta.Detail = "high"
|
||||
}
|
||||
// Whether to count image tokens at all
|
||||
if !constant.GetMediaToken {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
|
||||
// Decode image to get dimensions
|
||||
var config image.Config
|
||||
@@ -256,16 +258,15 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
}
|
||||
|
||||
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
||||
// 是否统计token
|
||||
if !constant.CountToken {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if meta == nil {
|
||||
return 0, errors.New("token count meta is nil")
|
||||
}
|
||||
|
||||
if !constant.GetMediaToken {
|
||||
return 0, nil
|
||||
}
|
||||
if !constant.GetMediaTokenNotStream && !info.IsStream {
|
||||
return 0, nil
|
||||
}
|
||||
if info.RelayFormat == types.RelayFormatOpenAIRealtime {
|
||||
return 0, nil
|
||||
}
|
||||
@@ -316,9 +317,19 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
shouldFetchFiles = false
|
||||
}
|
||||
|
||||
if shouldFetchFiles {
|
||||
for _, file := range meta.Files {
|
||||
if strings.HasPrefix(file.OriginData, "http") {
|
||||
// 是否本地计算媒体token数量
|
||||
if !constant.GetMediaToken {
|
||||
shouldFetchFiles = false
|
||||
}
|
||||
|
||||
// 是否在非流模式下本地计算媒体token数量
|
||||
if !constant.GetMediaTokenNotStream && !info.IsStream {
|
||||
shouldFetchFiles = false
|
||||
}
|
||||
|
||||
for _, file := range meta.Files {
|
||||
if strings.HasPrefix(file.OriginData, "http") {
|
||||
if shouldFetchFiles {
|
||||
mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting file base64 from url: %v", err)
|
||||
@@ -333,28 +344,28 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
file.FileType = types.FileTypeFile
|
||||
}
|
||||
file.MimeType = mineType
|
||||
} else if strings.HasPrefix(file.OriginData, "data:") {
|
||||
// get mime type from base64 header
|
||||
parts := strings.SplitN(file.OriginData, ",", 2)
|
||||
if len(parts) >= 1 {
|
||||
header := parts[0]
|
||||
// Extract mime type from "data:mime/type;base64" format
|
||||
if strings.Contains(header, ":") && strings.Contains(header, ";") {
|
||||
mimeStart := strings.Index(header, ":") + 1
|
||||
mimeEnd := strings.Index(header, ";")
|
||||
if mimeStart < mimeEnd {
|
||||
mineType := header[mimeStart:mimeEnd]
|
||||
if strings.HasPrefix(mineType, "image/") {
|
||||
file.FileType = types.FileTypeImage
|
||||
} else if strings.HasPrefix(mineType, "video/") {
|
||||
file.FileType = types.FileTypeVideo
|
||||
} else if strings.HasPrefix(mineType, "audio/") {
|
||||
file.FileType = types.FileTypeAudio
|
||||
} else {
|
||||
file.FileType = types.FileTypeFile
|
||||
}
|
||||
file.MimeType = mineType
|
||||
}
|
||||
} else if strings.HasPrefix(file.OriginData, "data:") {
|
||||
// get mime type from base64 header
|
||||
parts := strings.SplitN(file.OriginData, ",", 2)
|
||||
if len(parts) >= 1 {
|
||||
header := parts[0]
|
||||
// Extract mime type from "data:mime/type;base64" format
|
||||
if strings.Contains(header, ":") && strings.Contains(header, ";") {
|
||||
mimeStart := strings.Index(header, ":") + 1
|
||||
mimeEnd := strings.Index(header, ";")
|
||||
if mimeStart < mimeEnd {
|
||||
mineType := header[mimeStart:mimeEnd]
|
||||
if strings.HasPrefix(mineType, "image/") {
|
||||
file.FileType = types.FileTypeImage
|
||||
} else if strings.HasPrefix(mineType, "video/") {
|
||||
file.FileType = types.FileTypeVideo
|
||||
} else if strings.HasPrefix(mineType, "audio/") {
|
||||
file.FileType = types.FileTypeAudio
|
||||
} else {
|
||||
file.FileType = types.FileTypeFile
|
||||
}
|
||||
file.MimeType = mineType
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -365,7 +376,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
switch file.FileType {
|
||||
case types.FileTypeImage:
|
||||
if info.RelayFormat == types.RelayFormatGemini {
|
||||
tkm += 256
|
||||
tkm += 520 // gemini per input image tokens
|
||||
} else {
|
||||
token, err := getImageToken(file, model, info.IsStream)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
//func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
|
||||
@@ -16,7 +19,8 @@ import (
|
||||
// return 0, errors.New("unknown relay mode")
|
||||
//}
|
||||
|
||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
|
||||
func ResponseText2Usage(c *gin.Context, responseText string, modeName string, promptTokens int) *dto.Usage {
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
ctkm := CountTextToken(responseText, modeName)
|
||||
|
||||
@@ -17,8 +17,7 @@ type GeminiSettings struct {
|
||||
// 默认配置
|
||||
var defaultGeminiSettings = GeminiSettings{
|
||||
SafetySettings: map[string]string{
|
||||
"default": "OFF",
|
||||
"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
|
||||
"default": "OFF",
|
||||
},
|
||||
VersionSettings: map[string]string{
|
||||
"default": "v1beta",
|
||||
|
||||
@@ -598,6 +598,11 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
|
||||
return 2.5 / 0.3, false
|
||||
} else if strings.HasPrefix(name, "gemini-robotics-er-1.5") {
|
||||
return 2.5 / 0.3, false
|
||||
} else if strings.HasPrefix(name, "gemini-3-pro") {
|
||||
if strings.HasPrefix(name, "gemini-3-pro-image") {
|
||||
return 60, false
|
||||
}
|
||||
return 6, false
|
||||
}
|
||||
return 4, false
|
||||
}
|
||||
|
||||
21
setting/system_setting/discord.go
Normal file
21
setting/system_setting/discord.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package system_setting
|
||||
|
||||
import "github.com/QuantumNous/new-api/setting/config"
|
||||
|
||||
type DiscordSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ClientId string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var defaultDiscordSettings = DiscordSettings{}
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器
|
||||
config.GlobalConfig.Register("discord", &defaultDiscordSettings)
|
||||
}
|
||||
|
||||
func GetDiscordSettings() *DiscordSettings {
|
||||
return &defaultDiscordSettings
|
||||
}
|
||||
@@ -192,6 +192,14 @@ function App() {
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/discord'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>} key={location.pathname}>
|
||||
<OAuth2Callback type='discord'></OAuth2Callback>
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/oidc'
|
||||
element={
|
||||
|
||||
@@ -30,6 +30,7 @@ import {
|
||||
getSystemName,
|
||||
setUserData,
|
||||
onGitHubOAuthClicked,
|
||||
onDiscordOAuthClicked,
|
||||
onOIDCClicked,
|
||||
onLinuxDOOAuthClicked,
|
||||
prepareCredentialRequestOptions,
|
||||
@@ -53,6 +54,7 @@ import WeChatIcon from '../common/logo/WeChatIcon';
|
||||
import LinuxDoIcon from '../common/logo/LinuxDoIcon';
|
||||
import TwoFAVerification from './TwoFAVerification';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { SiDiscord }from 'react-icons/si';
|
||||
|
||||
const LoginForm = () => {
|
||||
let navigate = useNavigate();
|
||||
@@ -73,6 +75,7 @@ const LoginForm = () => {
|
||||
const [showEmailLogin, setShowEmailLogin] = useState(false);
|
||||
const [wechatLoading, setWechatLoading] = useState(false);
|
||||
const [githubLoading, setGithubLoading] = useState(false);
|
||||
const [discordLoading, setDiscordLoading] = useState(false);
|
||||
const [oidcLoading, setOidcLoading] = useState(false);
|
||||
const [linuxdoLoading, setLinuxdoLoading] = useState(false);
|
||||
const [emailLoginLoading, setEmailLoginLoading] = useState(false);
|
||||
@@ -298,6 +301,21 @@ const LoginForm = () => {
|
||||
}
|
||||
};
|
||||
|
||||
// 包装的Discord登录点击处理
|
||||
const handleDiscordClick = () => {
|
||||
if ((hasUserAgreement || hasPrivacyPolicy) && !agreedToTerms) {
|
||||
showInfo(t('请先阅读并同意用户协议和隐私政策'));
|
||||
return;
|
||||
}
|
||||
setDiscordLoading(true);
|
||||
try {
|
||||
onDiscordOAuthClicked(status.discord_client_id);
|
||||
} finally {
|
||||
// 由于重定向,这里不会执行到,但为了完整性添加
|
||||
setTimeout(() => setDiscordLoading(false), 3000);
|
||||
}
|
||||
};
|
||||
|
||||
// 包装的OIDC登录点击处理
|
||||
const handleOIDCClick = () => {
|
||||
if ((hasUserAgreement || hasPrivacyPolicy) && !agreedToTerms) {
|
||||
@@ -472,6 +490,19 @@ const LoginForm = () => {
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{status.discord_oauth && (
|
||||
<Button
|
||||
theme='outline'
|
||||
className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
|
||||
type='tertiary'
|
||||
icon={<SiDiscord style={{ color: '#5865F2', width: '20px', height: '20px' }} />}
|
||||
onClick={handleDiscordClick}
|
||||
loading={discordLoading}
|
||||
>
|
||||
<span className='ml-3'>{t('使用 Discord 继续')}</span>
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{status.oidc_enabled && (
|
||||
<Button
|
||||
theme='outline'
|
||||
@@ -714,6 +745,7 @@ const LoginForm = () => {
|
||||
</Form>
|
||||
|
||||
{(status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
@@ -849,6 +881,7 @@ const LoginForm = () => {
|
||||
{showEmailLogin ||
|
||||
!(
|
||||
status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
|
||||
@@ -28,6 +28,7 @@ import {
|
||||
updateAPI,
|
||||
getSystemName,
|
||||
setUserData,
|
||||
onDiscordOAuthClicked,
|
||||
} from '../../helpers';
|
||||
import Turnstile from 'react-turnstile';
|
||||
import { Button, Card, Checkbox, Divider, Form, Icon, Modal } from '@douyinfe/semi-ui';
|
||||
@@ -51,6 +52,7 @@ import WeChatIcon from '../common/logo/WeChatIcon';
|
||||
import TelegramLoginButton from 'react-telegram-login/src';
|
||||
import { UserContext } from '../../context/User';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { SiDiscord } from 'react-icons/si';
|
||||
|
||||
const RegisterForm = () => {
|
||||
let navigate = useNavigate();
|
||||
@@ -72,6 +74,7 @@ const RegisterForm = () => {
|
||||
const [showEmailRegister, setShowEmailRegister] = useState(false);
|
||||
const [wechatLoading, setWechatLoading] = useState(false);
|
||||
const [githubLoading, setGithubLoading] = useState(false);
|
||||
const [discordLoading, setDiscordLoading] = useState(false);
|
||||
const [oidcLoading, setOidcLoading] = useState(false);
|
||||
const [linuxdoLoading, setLinuxdoLoading] = useState(false);
|
||||
const [emailRegisterLoading, setEmailRegisterLoading] = useState(false);
|
||||
@@ -264,6 +267,15 @@ const RegisterForm = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const handleDiscordClick = () => {
|
||||
setDiscordLoading(true);
|
||||
try {
|
||||
onDiscordOAuthClicked(status.discord_client_id);
|
||||
} finally {
|
||||
setTimeout(() => setDiscordLoading(false), 3000);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOIDCClick = () => {
|
||||
setOidcLoading(true);
|
||||
try {
|
||||
@@ -377,6 +389,19 @@ const RegisterForm = () => {
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{status.discord_oauth && (
|
||||
<Button
|
||||
theme='outline'
|
||||
className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
|
||||
type='tertiary'
|
||||
icon={<SiDiscord style={{ color: '#5865F2', width: '20px', height: '20px' }} />}
|
||||
onClick={handleDiscordClick}
|
||||
loading={discordLoading}
|
||||
>
|
||||
<span className='ml-3'>{t('使用 Discord 继续')}</span>
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{status.oidc_enabled && (
|
||||
<Button
|
||||
theme='outline'
|
||||
@@ -591,6 +616,7 @@ const RegisterForm = () => {
|
||||
</Form>
|
||||
|
||||
{(status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
@@ -686,6 +712,7 @@ const RegisterForm = () => {
|
||||
{showEmailRegister ||
|
||||
!(
|
||||
status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
|
||||
@@ -20,7 +20,7 @@ For commercial licensing, please contact support@quantumnous.com
|
||||
import React from 'react';
|
||||
import { Button, Dropdown } from '@douyinfe/semi-ui';
|
||||
import { Languages } from 'lucide-react';
|
||||
import { CN, GB, FR, RU, JP } from 'country-flag-icons/react/3x2';
|
||||
import { CN, GB, FR, RU, JP, VN } from 'country-flag-icons/react/3x2';
|
||||
|
||||
const LanguageSelector = ({ currentLang, onLanguageChange, t }) => {
|
||||
return (
|
||||
@@ -65,6 +65,13 @@ const LanguageSelector = ({ currentLang, onLanguageChange, t }) => {
|
||||
<RU title='Русский' className='!w-5 !h-auto' />
|
||||
<span>Русский</span>
|
||||
</Dropdown.Item>
|
||||
<Dropdown.Item
|
||||
onClick={() => onLanguageChange('vi')}
|
||||
className={`!flex !items-center !gap-2 !px-3 !py-1.5 !text-sm !text-semi-color-text-0 dark:!text-gray-200 ${currentLang === 'vi' ? '!bg-semi-color-primary-light-default dark:!bg-blue-600 !font-semibold' : 'hover:!bg-semi-color-fill-1 dark:hover:!bg-gray-600'}`}
|
||||
>
|
||||
<VN title='Tiếng Việt' className='!w-5 !h-auto' />
|
||||
<span>Tiếng Việt</span>
|
||||
</Dropdown.Item>
|
||||
</Dropdown.Menu>
|
||||
}
|
||||
>
|
||||
|
||||
@@ -52,6 +52,9 @@ const SystemSetting = () => {
|
||||
GitHubOAuthEnabled: '',
|
||||
GitHubClientId: '',
|
||||
GitHubClientSecret: '',
|
||||
'discord.enabled': '',
|
||||
'discord.client_id': '',
|
||||
'discord.client_secret': '',
|
||||
'oidc.enabled': '',
|
||||
'oidc.client_id': '',
|
||||
'oidc.client_secret': '',
|
||||
@@ -179,6 +182,7 @@ const SystemSetting = () => {
|
||||
case 'EmailAliasRestrictionEnabled':
|
||||
case 'SMTPSSLEnabled':
|
||||
case 'LinuxDOOAuthEnabled':
|
||||
case 'discord.enabled':
|
||||
case 'oidc.enabled':
|
||||
case 'passkey.enabled':
|
||||
case 'passkey.allow_insecure_origin':
|
||||
@@ -473,6 +477,27 @@ const SystemSetting = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const submitDiscordOAuth = async () => {
|
||||
const options = [];
|
||||
|
||||
if (originInputs['discord.client_id'] !== inputs['discord.client_id']) {
|
||||
options.push({ key: 'discord.client_id', value: inputs['discord.client_id'] });
|
||||
}
|
||||
if (
|
||||
originInputs['discord.client_secret'] !== inputs['discord.client_secret'] &&
|
||||
inputs['discord.client_secret'] !== ''
|
||||
) {
|
||||
options.push({
|
||||
key: 'discord.client_secret',
|
||||
value: inputs['discord.client_secret'],
|
||||
});
|
||||
}
|
||||
|
||||
if (options.length > 0) {
|
||||
await updateOptions(options);
|
||||
}
|
||||
};
|
||||
|
||||
const submitOIDCSettings = async () => {
|
||||
if (inputs['oidc.well_known'] && inputs['oidc.well_known'] !== '') {
|
||||
if (
|
||||
@@ -1014,6 +1039,15 @@ const SystemSetting = () => {
|
||||
>
|
||||
{t('允许通过 GitHub 账户登录 & 注册')}
|
||||
</Form.Checkbox>
|
||||
<Form.Checkbox
|
||||
field='discord.enabled'
|
||||
noLabel
|
||||
onChange={(e) =>
|
||||
handleCheckboxChange('discord.enabled', e)
|
||||
}
|
||||
>
|
||||
{t('允许通过 Discord 账户登录 & 注册')}
|
||||
</Form.Checkbox>
|
||||
<Form.Checkbox
|
||||
field='LinuxDOOAuthEnabled'
|
||||
noLabel
|
||||
@@ -1410,6 +1444,37 @@ const SystemSetting = () => {
|
||||
</Button>
|
||||
</Form.Section>
|
||||
</Card>
|
||||
<Card>
|
||||
<Form.Section text={t('配置 Discord OAuth')}>
|
||||
<Text>{t('用以支持通过 Discord 进行登录注册')}</Text>
|
||||
<Banner
|
||||
type='info'
|
||||
description={`${t('Homepage URL 填')} ${inputs.ServerAddress ? inputs.ServerAddress : t('网站地址')},${t('Authorization callback URL 填')} ${inputs.ServerAddress ? inputs.ServerAddress : t('网站地址')}/oauth/discord`}
|
||||
style={{ marginBottom: 20, marginTop: 16 }}
|
||||
/>
|
||||
<Row
|
||||
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
|
||||
>
|
||||
<Col xs={24} sm={24} md={12} lg={12} xl={12}>
|
||||
<Form.Input
|
||||
field="['discord.client_id']"
|
||||
label={t('Discord Client ID')}
|
||||
/>
|
||||
</Col>
|
||||
<Col xs={24} sm={24} md={12} lg={12} xl={12}>
|
||||
<Form.Input
|
||||
field="['discord.client_secret']"
|
||||
label={t('Discord Client Secret')}
|
||||
type='password'
|
||||
placeholder={t('敏感信息不会发送到前端显示')}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Button onClick={submitDiscordOAuth}>
|
||||
{t('保存 Discord OAuth 设置')}
|
||||
</Button>
|
||||
</Form.Section>
|
||||
</Card>
|
||||
<Card>
|
||||
<Form.Section text={t('配置 Linux DO OAuth')}>
|
||||
<Text>
|
||||
|
||||
@@ -38,13 +38,14 @@ import {
|
||||
IconLock,
|
||||
IconDelete,
|
||||
} from '@douyinfe/semi-icons';
|
||||
import { SiTelegram, SiWechat, SiLinux } from 'react-icons/si';
|
||||
import { SiTelegram, SiWechat, SiLinux, SiDiscord } from 'react-icons/si';
|
||||
import { UserPlus, ShieldCheck } from 'lucide-react';
|
||||
import TelegramLoginButton from 'react-telegram-login';
|
||||
import {
|
||||
onGitHubOAuthClicked,
|
||||
onOIDCClicked,
|
||||
onLinuxDOOAuthClicked,
|
||||
onDiscordOAuthClicked,
|
||||
} from '../../../../helpers';
|
||||
import TwoFASetting from '../components/TwoFASetting';
|
||||
|
||||
@@ -247,6 +248,47 @@ const AccountManagement = ({
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
{/* Discord绑定 */}
|
||||
<Card className='!rounded-xl'>
|
||||
<div className='flex items-center justify-between gap-3'>
|
||||
<div className='flex items-center flex-1 min-w-0'>
|
||||
<div className='w-10 h-10 rounded-full bg-slate-100 dark:bg-slate-700 flex items-center justify-center mr-3 flex-shrink-0'>
|
||||
<SiDiscord
|
||||
size={20}
|
||||
className='text-slate-600 dark:text-slate-300'
|
||||
/>
|
||||
</div>
|
||||
<div className='flex-1 min-w-0'>
|
||||
<div className='font-medium text-gray-900'>
|
||||
{t('Discord')}
|
||||
</div>
|
||||
<div className='text-sm text-gray-500 truncate'>
|
||||
{renderAccountInfo(
|
||||
userState.user?.discord_id,
|
||||
t('Discord ID'),
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className='flex-shrink-0'>
|
||||
<Button
|
||||
type='primary'
|
||||
theme='outline'
|
||||
size='small'
|
||||
onClick={() =>
|
||||
onDiscordOAuthClicked(status.discord_client_id)
|
||||
}
|
||||
disabled={
|
||||
isBound(userState.user?.discord_id) ||
|
||||
!status.discord_oauth
|
||||
}
|
||||
>
|
||||
{status.discord_oauth ? t('绑定') : t('未启用')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
{/* OIDC绑定 */}
|
||||
<Card className='!rounded-xl'>
|
||||
<div className='flex items-center justify-between gap-3'>
|
||||
|
||||
@@ -190,6 +190,30 @@ const EditChannelModal = (props) => {
|
||||
const [keyMode, setKeyMode] = useState('append'); // 密钥模式:replace(覆盖)或 append(追加)
|
||||
const [isEnterpriseAccount, setIsEnterpriseAccount] = useState(false); // 是否为企业账户
|
||||
const [doubaoApiEditUnlocked, setDoubaoApiEditUnlocked] = useState(false); // 豆包渠道自定义 API 地址隐藏入口
|
||||
const redirectModelList = useMemo(() => {
|
||||
const mapping = inputs.model_mapping;
|
||||
if (typeof mapping !== 'string') return [];
|
||||
const trimmed = mapping.trim();
|
||||
if (!trimmed) return [];
|
||||
try {
|
||||
const parsed = JSON.parse(trimmed);
|
||||
if (
|
||||
!parsed ||
|
||||
typeof parsed !== 'object' ||
|
||||
Array.isArray(parsed)
|
||||
) {
|
||||
return [];
|
||||
}
|
||||
const values = Object.values(parsed)
|
||||
.map((value) =>
|
||||
typeof value === 'string' ? value.trim() : undefined,
|
||||
)
|
||||
.filter((value) => value);
|
||||
return Array.from(new Set(values));
|
||||
} catch (error) {
|
||||
return [];
|
||||
}
|
||||
}, [inputs.model_mapping]);
|
||||
|
||||
// 密钥显示状态
|
||||
const [keyDisplayState, setKeyDisplayState] = useState({
|
||||
@@ -220,6 +244,8 @@ const EditChannelModal = (props) => {
|
||||
];
|
||||
const formContainerRef = useRef(null);
|
||||
const doubaoApiClickCountRef = useRef(0);
|
||||
const initialModelsRef = useRef([]);
|
||||
const initialModelMappingRef = useRef('');
|
||||
|
||||
// 2FA状态更新辅助函数
|
||||
const updateTwoFAState = (updates) => {
|
||||
@@ -595,6 +621,10 @@ const EditChannelModal = (props) => {
|
||||
system_prompt: data.system_prompt,
|
||||
system_prompt_override: data.system_prompt_override || false,
|
||||
});
|
||||
initialModelsRef.current = (data.models || [])
|
||||
.map((model) => (model || '').trim())
|
||||
.filter(Boolean);
|
||||
initialModelMappingRef.current = data.model_mapping || '';
|
||||
// console.log(data);
|
||||
} else {
|
||||
showError(message);
|
||||
@@ -830,6 +860,13 @@ const EditChannelModal = (props) => {
|
||||
}
|
||||
}, [props.visible, channelId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isEdit) {
|
||||
initialModelsRef.current = [];
|
||||
initialModelMappingRef.current = '';
|
||||
}
|
||||
}, [isEdit, props.visible]);
|
||||
|
||||
// 统一的模态框重置函数
|
||||
const resetModalState = () => {
|
||||
formApiRef.current?.reset();
|
||||
@@ -903,6 +940,80 @@ const EditChannelModal = (props) => {
|
||||
})();
|
||||
};
|
||||
|
||||
const confirmMissingModelMappings = (missingModels) =>
|
||||
new Promise((resolve) => {
|
||||
const modal = Modal.confirm({
|
||||
title: t('模型未加入列表,可能无法调用'),
|
||||
content: (
|
||||
<div className='text-sm leading-6'>
|
||||
<div>
|
||||
{t(
|
||||
'模型重定向里的下列模型尚未添加到“模型”列表,调用时会因为缺少可用模型而失败:',
|
||||
)}
|
||||
</div>
|
||||
<div className='font-mono text-xs break-all text-red-600 mt-1'>
|
||||
{missingModels.join(', ')}
|
||||
</div>
|
||||
<div className='mt-2'>
|
||||
{t(
|
||||
'你可以在“自定义模型名称”处手动添加它们,然后点击填入后再提交,或者直接使用下方操作自动处理。',
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
centered: true,
|
||||
footer: (
|
||||
<Space align='center' className='w-full justify-end'>
|
||||
<Button
|
||||
type='tertiary'
|
||||
onClick={() => {
|
||||
modal.destroy();
|
||||
resolve('cancel');
|
||||
}}
|
||||
>
|
||||
{t('返回修改')}
|
||||
</Button>
|
||||
<Button
|
||||
type='primary'
|
||||
theme='light'
|
||||
onClick={() => {
|
||||
modal.destroy();
|
||||
resolve('submit');
|
||||
}}
|
||||
>
|
||||
{t('直接提交')}
|
||||
</Button>
|
||||
<Button
|
||||
type='primary'
|
||||
theme='solid'
|
||||
onClick={() => {
|
||||
modal.destroy();
|
||||
resolve('add');
|
||||
}}
|
||||
>
|
||||
{t('添加后提交')}
|
||||
</Button>
|
||||
</Space>
|
||||
),
|
||||
});
|
||||
});
|
||||
|
||||
const hasModelConfigChanged = (normalizedModels, modelMappingStr) => {
|
||||
if (!isEdit) return true;
|
||||
const initialModels = initialModelsRef.current;
|
||||
if (normalizedModels.length !== initialModels.length) {
|
||||
return true;
|
||||
}
|
||||
for (let i = 0; i < normalizedModels.length; i++) {
|
||||
if (normalizedModels[i] !== initialModels[i]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
const normalizedMapping = (modelMappingStr || '').trim();
|
||||
const initialMapping = (initialModelMappingRef.current || '').trim();
|
||||
return normalizedMapping !== initialMapping;
|
||||
};
|
||||
|
||||
const submit = async () => {
|
||||
const formValues = formApiRef.current ? formApiRef.current.getValues() : {};
|
||||
let localInputs = { ...formValues };
|
||||
@@ -986,14 +1097,55 @@ const EditChannelModal = (props) => {
|
||||
showInfo(t('请输入API地址!'));
|
||||
return;
|
||||
}
|
||||
if (
|
||||
localInputs.model_mapping &&
|
||||
localInputs.model_mapping !== '' &&
|
||||
!verifyJSON(localInputs.model_mapping)
|
||||
) {
|
||||
showInfo(t('模型映射必须是合法的 JSON 格式!'));
|
||||
return;
|
||||
const hasModelMapping =
|
||||
typeof localInputs.model_mapping === 'string' &&
|
||||
localInputs.model_mapping.trim() !== '';
|
||||
let parsedModelMapping = null;
|
||||
if (hasModelMapping) {
|
||||
if (!verifyJSON(localInputs.model_mapping)) {
|
||||
showInfo(t('模型映射必须是合法的 JSON 格式!'));
|
||||
return;
|
||||
}
|
||||
try {
|
||||
parsedModelMapping = JSON.parse(localInputs.model_mapping);
|
||||
} catch (error) {
|
||||
showInfo(t('模型映射必须是合法的 JSON 格式!'));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const normalizedModels = (localInputs.models || [])
|
||||
.map((model) => (model || '').trim())
|
||||
.filter(Boolean);
|
||||
localInputs.models = normalizedModels;
|
||||
|
||||
if (
|
||||
parsedModelMapping &&
|
||||
typeof parsedModelMapping === 'object' &&
|
||||
!Array.isArray(parsedModelMapping)
|
||||
) {
|
||||
const modelSet = new Set(normalizedModels);
|
||||
const missingModels = Object.keys(parsedModelMapping)
|
||||
.map((key) => (key || '').trim())
|
||||
.filter((key) => key && !modelSet.has(key));
|
||||
const shouldPromptMissing =
|
||||
missingModels.length > 0 &&
|
||||
hasModelConfigChanged(normalizedModels, localInputs.model_mapping);
|
||||
if (shouldPromptMissing) {
|
||||
const confirmAction = await confirmMissingModelMappings(missingModels);
|
||||
if (confirmAction === 'cancel') {
|
||||
return;
|
||||
}
|
||||
if (confirmAction === 'add') {
|
||||
const updatedModels = Array.from(
|
||||
new Set([...normalizedModels, ...missingModels]),
|
||||
);
|
||||
localInputs.models = updatedModels;
|
||||
handleInputChange('models', updatedModels);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
|
||||
localInputs.base_url = localInputs.base_url.slice(
|
||||
0,
|
||||
@@ -2916,6 +3068,7 @@ const EditChannelModal = (props) => {
|
||||
visible={modelModalVisible}
|
||||
models={fetchedModels}
|
||||
selected={inputs.models}
|
||||
redirectModels={redirectModelList}
|
||||
onConfirm={(selectedModels) => {
|
||||
handleInputChange('models', selectedModels);
|
||||
showSuccess(t('模型列表已更新'));
|
||||
|
||||
@@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import React, { useState, useEffect, useMemo } from 'react';
|
||||
import { useIsMobile } from '../../../../hooks/common/useIsMobile';
|
||||
import {
|
||||
Modal,
|
||||
@@ -28,12 +28,13 @@ import {
|
||||
Empty,
|
||||
Tabs,
|
||||
Collapse,
|
||||
Tooltip,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IllustrationNoResult,
|
||||
IllustrationNoResultDark,
|
||||
} from '@douyinfe/semi-illustrations';
|
||||
import { IconSearch } from '@douyinfe/semi-icons';
|
||||
import { IconSearch, IconInfoCircle } from '@douyinfe/semi-icons';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { getModelCategories } from '../../../../helpers/render';
|
||||
|
||||
@@ -41,6 +42,7 @@ const ModelSelectModal = ({
|
||||
visible,
|
||||
models = [],
|
||||
selected = [],
|
||||
redirectModels = [],
|
||||
onConfirm,
|
||||
onCancel,
|
||||
}) => {
|
||||
@@ -50,15 +52,54 @@ const ModelSelectModal = ({
|
||||
const [activeTab, setActiveTab] = useState('new');
|
||||
|
||||
const isMobile = useIsMobile();
|
||||
const normalizeModelName = (model) =>
|
||||
typeof model === 'string' ? model.trim() : '';
|
||||
const normalizedRedirectModels = useMemo(
|
||||
() =>
|
||||
Array.from(
|
||||
new Set(
|
||||
(redirectModels || [])
|
||||
.map((model) => normalizeModelName(model))
|
||||
.filter(Boolean),
|
||||
),
|
||||
),
|
||||
[redirectModels],
|
||||
);
|
||||
const normalizedSelectedSet = useMemo(() => {
|
||||
const set = new Set();
|
||||
(selected || []).forEach((model) => {
|
||||
const normalized = normalizeModelName(model);
|
||||
if (normalized) {
|
||||
set.add(normalized);
|
||||
}
|
||||
});
|
||||
return set;
|
||||
}, [selected]);
|
||||
const classificationSet = useMemo(() => {
|
||||
const set = new Set(normalizedSelectedSet);
|
||||
normalizedRedirectModels.forEach((model) => set.add(model));
|
||||
return set;
|
||||
}, [normalizedSelectedSet, normalizedRedirectModels]);
|
||||
const redirectOnlySet = useMemo(() => {
|
||||
const set = new Set();
|
||||
normalizedRedirectModels.forEach((model) => {
|
||||
if (!normalizedSelectedSet.has(model)) {
|
||||
set.add(model);
|
||||
}
|
||||
});
|
||||
return set;
|
||||
}, [normalizedRedirectModels, normalizedSelectedSet]);
|
||||
|
||||
const filteredModels = models.filter((m) =>
|
||||
m.toLowerCase().includes(keyword.toLowerCase()),
|
||||
String(m || '').toLowerCase().includes(keyword.toLowerCase()),
|
||||
);
|
||||
|
||||
// 分类模型:新获取的模型和已有模型
|
||||
const newModels = filteredModels.filter((model) => !selected.includes(model));
|
||||
const isExistingModel = (model) =>
|
||||
classificationSet.has(normalizeModelName(model));
|
||||
const newModels = filteredModels.filter((model) => !isExistingModel(model));
|
||||
const existingModels = filteredModels.filter((model) =>
|
||||
selected.includes(model),
|
||||
isExistingModel(model),
|
||||
);
|
||||
|
||||
// 同步外部选中值
|
||||
@@ -228,7 +269,20 @@ const ModelSelectModal = ({
|
||||
<div className='grid grid-cols-2 gap-x-4'>
|
||||
{categoryData.models.map((model) => (
|
||||
<Checkbox key={model} value={model} className='my-1'>
|
||||
{model}
|
||||
<span className='flex items-center gap-2'>
|
||||
<span>{model}</span>
|
||||
{redirectOnlySet.has(normalizeModelName(model)) && (
|
||||
<Tooltip
|
||||
position='top'
|
||||
content={t('来自模型重定向,尚未加入模型列表')}
|
||||
>
|
||||
<IconInfoCircle
|
||||
size='small'
|
||||
className='text-amber-500 cursor-help'
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
</span>
|
||||
</Checkbox>
|
||||
))}
|
||||
</div>
|
||||
|
||||
@@ -72,6 +72,7 @@ const EditUserModal = (props) => {
|
||||
password: '',
|
||||
github_id: '',
|
||||
oidc_id: '',
|
||||
discord_id: '',
|
||||
wechat_id: '',
|
||||
telegram_id: '',
|
||||
email: '',
|
||||
@@ -332,6 +333,7 @@ const EditUserModal = (props) => {
|
||||
<Row gutter={12}>
|
||||
{[
|
||||
'github_id',
|
||||
'discord_id',
|
||||
'oidc_id',
|
||||
'wechat_id',
|
||||
'email',
|
||||
|
||||
@@ -231,6 +231,17 @@ export async function getOAuthState() {
|
||||
}
|
||||
}
|
||||
|
||||
export async function onDiscordOAuthClicked(client_id) {
|
||||
const state = await getOAuthState();
|
||||
if (!state) return;
|
||||
const redirect_uri = `${window.location.origin}/oauth/discord`;
|
||||
const response_type = 'code';
|
||||
const scope = 'identify+openid';
|
||||
window.open(
|
||||
`https://discord.com/oauth2/authorize?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`,
|
||||
);
|
||||
}
|
||||
|
||||
export async function onOIDCClicked(auth_url, client_id, openInNewTab = false) {
|
||||
const state = await getOAuthState();
|
||||
if (!state) return;
|
||||
|
||||
@@ -482,6 +482,18 @@ export const useLogsData = () => {
|
||||
value: other.request_path,
|
||||
});
|
||||
}
|
||||
if (isAdminUser) {
|
||||
let localCountMode = '';
|
||||
if (other?.admin_info?.local_count_tokens) {
|
||||
localCountMode = t('本地计费');
|
||||
} else {
|
||||
localCountMode = t('上游返回');
|
||||
}
|
||||
expandDataLocal.push({
|
||||
key: t('计费模式'),
|
||||
value: localCountMode,
|
||||
});
|
||||
}
|
||||
expandDatesLocal[logs[i].key] = expandDataLocal;
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ import frTranslation from './locales/fr.json';
|
||||
import zhTranslation from './locales/zh.json';
|
||||
import ruTranslation from './locales/ru.json';
|
||||
import jaTranslation from './locales/ja.json';
|
||||
import viTranslation from './locales/vi.json';
|
||||
|
||||
i18n
|
||||
.use(LanguageDetector)
|
||||
@@ -38,6 +39,7 @@ i18n
|
||||
fr: frTranslation,
|
||||
ru: ruTranslation,
|
||||
ja: jaTranslation,
|
||||
vi: viTranslation,
|
||||
},
|
||||
fallbackLng: 'zh',
|
||||
interpolation: {
|
||||
|
||||
2700
web/src/i18n/locales/vi.json
Normal file
2700
web/src/i18n/locales/vi.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -257,6 +257,7 @@
|
||||
"余额充值管理": "余额充值管理",
|
||||
"你似乎并没有修改什么": "你似乎并没有修改什么",
|
||||
"使用 GitHub 继续": "使用 GitHub 继续",
|
||||
"使用 Discord 继续": "使用 Discord 继续",
|
||||
"使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}": "使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}",
|
||||
"使用 LinuxDO 继续": "使用 LinuxDO 继续",
|
||||
"使用 OIDC 继续": "使用 OIDC 继续",
|
||||
|
||||
Reference in New Issue
Block a user