mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-01 21:15:48 +00:00
Compare commits
23 Commits
v0.9.24
...
coderabbit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8aedbb29c3 | ||
|
|
3f19f18dc9 | ||
|
|
a465597e78 | ||
|
|
dbfcb441f7 | ||
|
|
3fb2ba318d | ||
|
|
8f039b3a53 | ||
|
|
c939686509 | ||
|
|
07aff1fe02 | ||
|
|
5f27edcd19 | ||
|
|
f47d473e63 | ||
|
|
7a2bd38700 | ||
|
|
f8c40ecca6 | ||
|
|
2bc991685f | ||
|
|
87811a0493 | ||
|
|
0885597427 | ||
|
|
0952973887 | ||
|
|
6b30f042fa | ||
|
|
efb8f1f5b8 | ||
|
|
de3cf9893d | ||
|
|
fe02e9a066 | ||
|
|
84745d5ca4 | ||
|
|
cdb1c06ad2 | ||
|
|
d9b5748f80 |
@@ -63,7 +63,7 @@
|
||||
# 是否统计图片token
|
||||
# GET_MEDIA_TOKEN=true
|
||||
# 是否在非流(stream=false)情况下统计图片token
|
||||
# GET_MEDIA_TOKEN_NOT_STREAM=true
|
||||
# GET_MEDIA_TOKEN_NOT_STREAM=false
|
||||
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
|
||||
# DIFY_DEBUG=true
|
||||
|
||||
|
||||
24
.github/workflows/release.yml
vendored
24
.github/workflows/release.yml
vendored
@@ -22,6 +22,10 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
@@ -31,7 +35,7 @@ jobs:
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
@@ -40,13 +44,11 @@ jobs:
|
||||
- name: Build Backend (amd64)
|
||||
run: |
|
||||
go mod download
|
||||
VERSION=$(git describe --tags)
|
||||
go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION' -extldflags '-static'" -o new-api-$VERSION
|
||||
- name: Build Backend (arm64)
|
||||
run: |
|
||||
sudo apt-get update
|
||||
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
|
||||
VERSION=$(git describe --tags)
|
||||
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION' -extldflags '-static'" -o new-api-arm64-$VERSION
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
@@ -65,6 +67,10 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
@@ -75,7 +81,7 @@ jobs:
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
@@ -84,7 +90,6 @@ jobs:
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
VERSION=$(git describe --tags)
|
||||
go build -ldflags "-X 'new-api/common.Version=$VERSION'" -o new-api-macos-$VERSION
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
@@ -105,6 +110,10 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
@@ -114,7 +123,7 @@ jobs:
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
@@ -123,7 +132,6 @@ jobs:
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
VERSION=$(git describe --tags)
|
||||
go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION'" -o new-api-$VERSION.exe
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
@@ -132,5 +140,3 @@ jobs:
|
||||
files: new-api-*.exe
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
|
||||
|
||||
@@ -193,6 +193,7 @@ docker run --name new-api -d --restart always \
|
||||
|
||||
### 🔐 Authorization and Security
|
||||
|
||||
- 😈 Discord authorization login
|
||||
- 🤖 LinuxDO authorization login
|
||||
- 📱 Telegram authorization login
|
||||
- 🔑 OIDC unified authentication
|
||||
|
||||
@@ -193,6 +193,7 @@ docker run --name new-api -d --restart always \
|
||||
|
||||
### 🔐 授权与安全
|
||||
|
||||
- 😈 Discord 授权登录
|
||||
- 🤖 LinuxDO 授权登录
|
||||
- 📱 Telegram 授权登录
|
||||
- 🔑 OIDC 统一认证
|
||||
|
||||
@@ -30,6 +30,11 @@ func printHelp() {
|
||||
func InitEnv() {
|
||||
flag.Parse()
|
||||
|
||||
envVersion := os.Getenv("VERSION")
|
||||
if envVersion != "" {
|
||||
Version = envVersion
|
||||
}
|
||||
|
||||
if *PrintVersion {
|
||||
fmt.Println(Version)
|
||||
os.Exit(0)
|
||||
@@ -111,8 +116,9 @@ func initConstantEnv() {
|
||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
|
||||
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", false)
|
||||
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
|
||||
@@ -46,5 +46,7 @@ const (
|
||||
ContextKeyUsingGroup ContextKey = "group"
|
||||
ContextKeyUserName ContextKey = "username"
|
||||
|
||||
ContextKeyLocalCountTokens ContextKey = "local_count_tokens"
|
||||
|
||||
ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ var StreamingTimeout int
|
||||
var DifyDebug bool
|
||||
var MaxFileDownloadMB int
|
||||
var ForceStreamOption bool
|
||||
var CountToken bool
|
||||
var GetMediaToken bool
|
||||
var GetMediaTokenNotStream bool
|
||||
var UpdateTask bool
|
||||
|
||||
223
controller/discord.go
Normal file
223
controller/discord.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DiscordResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type DiscordUser struct {
|
||||
UID string `json:"id"`
|
||||
ID string `json:"username"`
|
||||
Name string `json:"global_name"`
|
||||
}
|
||||
|
||||
func getDiscordUserInfoByCode(code string) (*DiscordUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("client_id", system_setting.GetDiscordSettings().ClientId)
|
||||
values.Set("client_secret", system_setting.GetDiscordSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var discordResponse DiscordResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&discordResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if discordResponse.AccessToken == "" {
|
||||
common.SysError("Discord 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("Discord 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysError("Discord 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("Discord 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
var discordUser DiscordUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&discordUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if discordUser.UID == "" || discordUser.ID == "" {
|
||||
common.SysError("Discord 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("Discord 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &discordUser, nil
|
||||
}
|
||||
|
||||
func DiscordOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
DiscordBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
discordUser, err := getDiscordUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
DiscordId: discordUser.UID,
|
||||
}
|
||||
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
|
||||
err := user.FillUserByDiscordId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
if discordUser.ID != "" {
|
||||
user.Username = discordUser.ID
|
||||
} else {
|
||||
user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if discordUser.Name != "" {
|
||||
user.DisplayName = discordUser.Name
|
||||
} else {
|
||||
user.DisplayName = "Discord User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func DiscordBind(c *gin.Context) {
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
discordUser, err := getDiscordUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
DiscordId: discordUser.UID,
|
||||
}
|
||||
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 Discord 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.DiscordId = discordUser.UID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
}
|
||||
@@ -52,6 +52,8 @@ func GetStatus(c *gin.Context) {
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"discord_oauth": system_setting.GetDiscordSettings().Enabled,
|
||||
"discord_client_id": system_setting.GetDiscordSettings().ClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,
|
||||
|
||||
@@ -71,6 +71,14 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "discord.enabled":
|
||||
if option.Value == "true" && system_setting.GetDiscordSettings().ClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 Discord OAuth,请先填入 Discord Client Id 以及 Discord Client Secret!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "oidc.enabled":
|
||||
if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -453,6 +453,7 @@ func GetSelf(c *gin.Context) {
|
||||
"status": user.Status,
|
||||
"email": user.Email,
|
||||
"github_id": user.GitHubId,
|
||||
"discord_id": user.DiscordId,
|
||||
"oidc_id": user.OidcId,
|
||||
"wechat_id": user.WeChatId,
|
||||
"telegram_id": user.TelegramId,
|
||||
|
||||
@@ -42,6 +42,7 @@
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 |
|
||||
| GET | /api/oauth/discord | 公开 | Discord 通用 OAuth 跳转 |
|
||||
| GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 |
|
||||
| GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 |
|
||||
| GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 |
|
||||
|
||||
@@ -142,7 +142,38 @@ type GeminiThinkingConfig struct {
|
||||
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
// TODO Conflict with thinkingbudget.
|
||||
// ThinkingLevel json.RawMessage `json:"thinkingLevel,omitempty"`
|
||||
ThinkingLevel json.RawMessage `json:"thinkingLevel,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON allows GeminiThinkingConfig to accept both snake_case and camelCase fields.
|
||||
func (c *GeminiThinkingConfig) UnmarshalJSON(data []byte) error {
|
||||
type Alias GeminiThinkingConfig
|
||||
var aux struct {
|
||||
Alias
|
||||
IncludeThoughtsSnake *bool `json:"include_thoughts,omitempty"`
|
||||
ThinkingBudgetSnake *int `json:"thinking_budget,omitempty"`
|
||||
ThinkingLevelSnake json.RawMessage `json:"thinking_level,omitempty"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GeminiThinkingConfig(aux.Alias)
|
||||
|
||||
if aux.IncludeThoughtsSnake != nil {
|
||||
c.IncludeThoughts = *aux.IncludeThoughtsSnake
|
||||
}
|
||||
|
||||
if aux.ThinkingBudgetSnake != nil {
|
||||
c.ThinkingBudget = aux.ThinkingBudgetSnake
|
||||
}
|
||||
|
||||
if len(aux.ThinkingLevelSnake) > 0 {
|
||||
c.ThinkingLevel = aux.ThinkingLevelSnake
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
|
||||
|
||||
@@ -897,6 +897,12 @@ type Reasoning struct {
|
||||
Summary string `json:"summary,omitempty"`
|
||||
}
|
||||
|
||||
type Input struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type MediaInput struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
@@ -915,7 +921,7 @@ func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
|
||||
return nil
|
||||
}
|
||||
|
||||
var inputs []MediaInput
|
||||
var mediaInputs []MediaInput
|
||||
|
||||
// Try string first
|
||||
// if str, ok := common.GetJsonType(r.Input); ok {
|
||||
@@ -925,60 +931,74 @@ func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
|
||||
if common.GetJsonType(r.Input) == "string" {
|
||||
var str string
|
||||
_ = common.Unmarshal(r.Input, &str)
|
||||
inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
|
||||
return inputs
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_text", Text: str})
|
||||
return mediaInputs
|
||||
}
|
||||
|
||||
// Try array of parts
|
||||
if common.GetJsonType(r.Input) == "array" {
|
||||
var array []any
|
||||
_ = common.Unmarshal(r.Input, &array)
|
||||
for _, itemAny := range array {
|
||||
// Already parsed MediaInput
|
||||
if media, ok := itemAny.(MediaInput); ok {
|
||||
inputs = append(inputs, media)
|
||||
continue
|
||||
var inputs []Input
|
||||
_ = common.Unmarshal(r.Input, &inputs)
|
||||
for _, input := range inputs {
|
||||
if common.GetJsonType(input.Content) == "string" {
|
||||
var str string
|
||||
_ = common.Unmarshal(input.Content, &str)
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_text", Text: str})
|
||||
}
|
||||
// Generic map
|
||||
item, ok := itemAny.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typeVal, ok := item["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch typeVal {
|
||||
case "input_text":
|
||||
text, _ := item["text"].(string)
|
||||
inputs = append(inputs, MediaInput{Type: "input_text", Text: text})
|
||||
case "input_image":
|
||||
// image_url may be string or object with url field
|
||||
var imageUrl string
|
||||
switch v := item["image_url"].(type) {
|
||||
case string:
|
||||
imageUrl = v
|
||||
case map[string]any:
|
||||
if url, ok := v["url"].(string); ok {
|
||||
imageUrl = url
|
||||
|
||||
if common.GetJsonType(input.Content) == "array" {
|
||||
var array []any
|
||||
_ = common.Unmarshal(input.Content, &array)
|
||||
for _, itemAny := range array {
|
||||
// Already parsed MediaContent
|
||||
if media, ok := itemAny.(MediaInput); ok {
|
||||
mediaInputs = append(mediaInputs, media)
|
||||
continue
|
||||
}
|
||||
|
||||
// Generic map
|
||||
item, ok := itemAny.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
typeVal, ok := item["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch typeVal {
|
||||
case "input_text":
|
||||
text, _ := item["text"].(string)
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_text", Text: text})
|
||||
case "input_image":
|
||||
// image_url may be string or object with url field
|
||||
var imageUrl string
|
||||
switch v := item["image_url"].(type) {
|
||||
case string:
|
||||
imageUrl = v
|
||||
case map[string]any:
|
||||
if url, ok := v["url"].(string); ok {
|
||||
imageUrl = url
|
||||
}
|
||||
}
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
|
||||
case "input_file":
|
||||
// file_url may be string or object with url field
|
||||
var fileUrl string
|
||||
switch v := item["file_url"].(type) {
|
||||
case string:
|
||||
fileUrl = v
|
||||
case map[string]any:
|
||||
if url, ok := v["url"].(string); ok {
|
||||
fileUrl = url
|
||||
}
|
||||
}
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
|
||||
}
|
||||
}
|
||||
inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
|
||||
case "input_file":
|
||||
// file_url may be string or object with url field
|
||||
var fileUrl string
|
||||
switch v := item["file_url"].(type) {
|
||||
case string:
|
||||
fileUrl = v
|
||||
case map[string]any:
|
||||
if url, ok := v["url"].(string); ok {
|
||||
fileUrl = url
|
||||
}
|
||||
}
|
||||
inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return inputs
|
||||
return mediaInputs
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ type User struct {
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
DiscordId string `json:"discord_id" gorm:"column:discord_id;index"`
|
||||
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
|
||||
@@ -539,6 +540,14 @@ func (user *User) FillUserByGitHubId() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByDiscordId() error {
|
||||
if user.DiscordId == "" {
|
||||
return errors.New("discord id 为空!")
|
||||
}
|
||||
DB.Where(User{DiscordId: user.DiscordId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByOidcId() error {
|
||||
if user.OidcId == "" {
|
||||
return errors.New("oidc id 为空!")
|
||||
@@ -578,6 +587,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
|
||||
return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsDiscordIdAlreadyTaken(discordId string) bool {
|
||||
return DB.Unscoped().Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsOidcIdAlreadyTaken(oidcId string) bool {
|
||||
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
@@ -673,7 +673,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
@@ -682,7 +682,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
||||
if common.DebugEnabled {
|
||||
common.SysLog("claude response usage is not complete, maybe upstream error")
|
||||
}
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
|
||||
if err := scanner.Err(); err != nil {
|
||||
logger.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||
}
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
@@ -105,7 +105,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response)
|
||||
for _, choice := range response.Choices {
|
||||
responseText += choice.Message.StringContent()
|
||||
}
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
response.Usage = *usage
|
||||
response.Id = helper.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
|
||||
@@ -165,7 +165,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
}
|
||||
})
|
||||
if usage.PromptTokens == 0 {
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -142,7 +142,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
|
||||
helper.Done(c)
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
|
||||
@@ -246,7 +246,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
})
|
||||
helper.Done(c)
|
||||
if usage.TotalTokens == 0 {
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
usage.CompletionTokens += nodeToken
|
||||
return usage, nil
|
||||
|
||||
@@ -32,7 +32,7 @@ var SafetySettingList = []string{
|
||||
"HARM_CATEGORY_HATE_SPEECH",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"HARM_CATEGORY_CIVIC_INTEGRITY",
|
||||
//"HARM_CATEGORY_CIVIC_INTEGRITY", This item is deprecated!
|
||||
}
|
||||
|
||||
var ChannelName = "google gemini"
|
||||
|
||||
@@ -3,9 +3,9 @@ package gemini
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
@@ -13,8 +13,6 @@ import (
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -77,6 +75,8 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
|
||||
if info.IsGeminiBatchEmbedding {
|
||||
var geminiResponse dto.GeminiBatchEmbeddingResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
@@ -97,80 +97,15 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
|
||||
}
|
||||
|
||||
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
// 统计图片数量
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||
imageCount++
|
||||
}
|
||||
if part.Text != "" {
|
||||
responseText.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新使用量统计
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
|
||||
// 直接发送 GeminiChatResponse 响应
|
||||
err = helper.StringData(c, data)
|
||||
err := helper.StringData(c, data)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
info.SendResponseCount++
|
||||
return true
|
||||
})
|
||||
|
||||
if info.SendResponseCount == 0 {
|
||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 258
|
||||
}
|
||||
}
|
||||
|
||||
// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
|
||||
if usage.CompletionTokens == 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 空补全,不需要使用量
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
}
|
||||
|
||||
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
||||
//helper.Done(c)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -954,14 +954,10 @@ func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.Ch
|
||||
return nil
|
||||
}
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
// responseText := ""
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
responseText := strings.Builder{}
|
||||
func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response, callback func(data string, geminiResponse *dto.GeminiChatResponse) bool) (*dto.Usage, *types.NewAPIError) {
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
finishReason := constant.FinishReasonStop
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
@@ -971,6 +967,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
return false
|
||||
}
|
||||
|
||||
// 统计图片数量
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||
@@ -982,14 +979,10 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
}
|
||||
}
|
||||
|
||||
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
// 更新使用量统计
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
@@ -1000,6 +993,45 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return callback(data, &geminiResponse)
|
||||
})
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 1400
|
||||
}
|
||||
}
|
||||
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
if usage.TotalTokens > 0 {
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
}
|
||||
|
||||
if usage.CompletionTokens <= 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
finishReason := constant.FinishReasonStop
|
||||
|
||||
usage, err := geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
|
||||
response, isStop := streamResponseGeminiChat2OpenAI(geminiResponse)
|
||||
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
|
||||
logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
|
||||
if info.SendResponseCount == 0 {
|
||||
// send first response
|
||||
@@ -1015,7 +1047,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
emptyResponse.Choices[0].Delta.ToolCalls = copiedToolCalls
|
||||
}
|
||||
finishReason = constant.FinishReasonToolCalls
|
||||
err = handleStream(c, info, emptyResponse)
|
||||
err := handleStream(c, info, emptyResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
@@ -1025,14 +1057,14 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
response.Choices[0].FinishReason = nil
|
||||
}
|
||||
} else {
|
||||
err = handleStream(c, info, emptyResponse)
|
||||
err := handleStream(c, info, emptyResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = handleStream(c, info, response)
|
||||
err := handleStream(c, info, response)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
@@ -1042,40 +1074,15 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
return true
|
||||
})
|
||||
|
||||
if info.SendResponseCount == 0 {
|
||||
// 空补全,报错不计费
|
||||
// empty response, throw an error
|
||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 258
|
||||
}
|
||||
}
|
||||
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
|
||||
if usage.CompletionTokens == 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 空补全,不需要使用量
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
if err != nil {
|
||||
return usage, err
|
||||
}
|
||||
|
||||
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
err := handleFinalStream(c, info, response)
|
||||
if err != nil {
|
||||
common.SysLog("send final response failed: " + err.Error())
|
||||
handleErr := handleFinalStream(c, info, response)
|
||||
if handleErr != nil {
|
||||
common.SysLog("send final response failed: " + handleErr.Error())
|
||||
}
|
||||
//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
// helper.Done(c)
|
||||
//}
|
||||
//resp.Body.Close()
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
@@ -18,10 +19,26 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 辅助函数
|
||||
// HandleStreamFormat processes a streaming response payload according to the provided RelayInfo and forwards it to the appropriate format-specific handler.
|
||||
//
|
||||
// It increments info.SendResponseCount, optionally converts OpenRouter "reasoning" fields to "reasoning_content" when the channel is OpenRouter and OpenRouterConvertToOpenAI is enabled, and then dispatches the (possibly modified) JSON string to the handler for the configured RelayFormat (OpenAI, Claude, or Gemini). It returns any error produced by the selected handler or nil if no handler is invoked.
|
||||
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
info.SendResponseCount++
|
||||
|
||||
// OpenRouter reasoning 字段转换:reasoning -> reasoning_content
|
||||
// 仅当启用转换为OpenAI兼容格式时执行
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.OpenRouterConvertToOpenAI {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err == nil {
|
||||
convertOpenRouterReasoningFieldsStream(&streamResponse)
|
||||
// 重新序列化为JSON
|
||||
newData, err := common.Marshal(streamResponse)
|
||||
if err == nil {
|
||||
data = string(newData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatOpenAI:
|
||||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||||
@@ -253,9 +270,26 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
}
|
||||
}
|
||||
|
||||
// sendResponsesStreamData sends a non-empty data chunk for the given stream response to the client.
|
||||
// If data is empty, it returns without sending anything.
|
||||
func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
|
||||
if data == "" {
|
||||
return
|
||||
}
|
||||
helper.ResponseChunkData(c, streamResponse, data)
|
||||
}
|
||||
|
||||
// convertOpenRouterReasoningFieldsStream converts each choice's `Delta` in a streaming ChatCompletions response
|
||||
// by normalizing any `reasoning` fields into `reasoning_content`.
|
||||
// It applies ConvertReasoningField to every choice's Delta and is a no-op if `response` is nil or has no choices.
|
||||
func convertOpenRouterReasoningFieldsStream(response *dto.ChatCompletionsStreamResponse) {
|
||||
if response == nil || len(response.Choices) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 遍历所有choices,对每个Delta使用统一的泛型函数进行转换
|
||||
for i := range response.Choices {
|
||||
choice := &response.Choices[i]
|
||||
ConvertReasoningField(&choice.Delta)
|
||||
}
|
||||
}
|
||||
35
relay/channel/openai/reasoning_converter.go
Normal file
35
relay/channel/openai/reasoning_converter.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package openai
|
||||
|
||||
// ReasoningHolder 定义一个通用的接口,用于操作包含reasoning字段的结构体
|
||||
type ReasoningHolder interface {
|
||||
// 获取reasoning字段的值
|
||||
GetReasoning() string
|
||||
// 设置reasoning字段的值
|
||||
SetReasoning(reasoning string)
|
||||
// 获取reasoning_content字段的值
|
||||
GetReasoningContent() string
|
||||
// 设置reasoning_content字段的值
|
||||
SetReasoningContent(reasoningContent string)
|
||||
}
|
||||
|
||||
// ConvertReasoningField 通用的reasoning字段转换函数
|
||||
// 将reasoning字段的内容移动到reasoning_content字段
|
||||
// ConvertReasoningField moves the holder's reasoning into its reasoning content and clears the original reasoning field.
|
||||
// If GetReasoning returns an empty string, the holder is unchanged. When clearing, types that implement SetReasoningToNil()
|
||||
// will have that method invoked; otherwise SetReasoning("") is used.
|
||||
func ConvertReasoningField[T ReasoningHolder](holder T) {
|
||||
reasoning := holder.GetReasoning()
|
||||
if reasoning != "" {
|
||||
holder.SetReasoningContent(reasoning)
|
||||
}
|
||||
|
||||
// 使用类型断言来智能清理reasoning字段
|
||||
switch h := any(holder).(type) {
|
||||
case interface{ SetReasoningToNil() }:
|
||||
// 流式响应:指针类型,设为nil
|
||||
h.SetReasoningToNil()
|
||||
default:
|
||||
// 非流式响应:值类型,设为空字符串
|
||||
holder.SetReasoning("")
|
||||
}
|
||||
}
|
||||
@@ -183,7 +183,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
@@ -194,6 +194,25 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// OpenaiHandler processes an upstream OpenAI-like HTTP response, normalizes or infers token usage,
|
||||
// optionally converts OpenRouter reasoning fields to OpenAI-compatible `reasoning_content`, adapts
|
||||
// the response to the configured relay format (OpenAI, Claude, or Gemini), writes the final body
|
||||
// to the client, and returns the computed usage.
|
||||
//
|
||||
// It will:
|
||||
// - Handle OpenRouter enterprise wrapper responses when the channel is OpenRouter Enterprise.
|
||||
// - Unmarshal the upstream body into an internal simple response and, when configured,
|
||||
// convert OpenRouter `reasoning` fields into `reasoning_content`.
|
||||
// - If usage prompt tokens are missing, infer completion tokens by counting tokens in choices
|
||||
// (falling back to per-choice text token counting) and set Prompt/Completion/Total tokens.
|
||||
// - Apply channel-specific post-processing to usage (cached token adjustments).
|
||||
// - Depending on RelayFormat and channel settings, inject updated usage into the body,
|
||||
// reserialize the converted simple response when ForceFormat is enabled or when OpenRouter
|
||||
// conversion was applied, or convert the response to Claude/Gemini formats.
|
||||
// - Write the final response body to the client via a graceful copy helper.
|
||||
//
|
||||
// Returns the final usage (possibly inferred or modified) or a NewAPIError describing any failure
|
||||
// encountered while reading, parsing, or transforming the upstream response.
|
||||
func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
@@ -226,6 +245,12 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// OpenRouter reasoning 字段转换:reasoning -> reasoning_content
|
||||
// 仅当启用转换为OpenAI兼容格式时执行(修改现有无条件转换)
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.OpenRouterConvertToOpenAI {
|
||||
convertOpenRouterReasoningFields(&simpleResponse)
|
||||
}
|
||||
|
||||
if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
@@ -271,6 +296,13 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
} else {
|
||||
// 对于 OpenRouter,仅在执行转换后重新序列化
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.OpenRouterConvertToOpenAI {
|
||||
responseBody, err = common.Marshal(simpleResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case types.RelayFormatClaude:
|
||||
@@ -672,6 +704,10 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
|
||||
}
|
||||
}
|
||||
|
||||
// extractCachedTokensFromBody extracts a cached token count from a JSON response body.
|
||||
// It looks for cached token values in the following fields (in order): `usage.prompt_tokens_details.cached_tokens`,
|
||||
// `usage.cached_tokens`, and `usage.prompt_cache_hit_tokens`. It returns the first found value and `true`;
|
||||
// if none are present or the body cannot be parsed, it returns 0 and `false`.
|
||||
func extractCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
@@ -702,3 +738,18 @@ func extractCachedTokensFromBody(body []byte) (int, bool) {
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// convertOpenRouterReasoningFields 转换OpenRouter响应中的reasoning字段为reasoning_content
|
||||
// convertOpenRouterReasoningFields converts OpenRouter-style `reasoning` fields into `reasoning_content` for every choice's message in the provided OpenAITextResponse.
|
||||
// It modifies the response in place and is a no-op if `response` is nil or contains no choices.
|
||||
func convertOpenRouterReasoningFields(response *dto.OpenAITextResponse) {
|
||||
if response == nil || len(response.Choices) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 遍历所有choices,对每个Message使用统一的泛型函数进行转换
|
||||
for i := range response.Choices {
|
||||
choice := &response.Choices[i]
|
||||
ConvertReasoningField(&choice.Message)
|
||||
}
|
||||
}
|
||||
@@ -81,7 +81,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = palmStreamHandler(c, resp)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
usage, err = palmHandler(c, info, resp)
|
||||
}
|
||||
|
||||
@@ -130,7 +130,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
|
||||
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
|
||||
return service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens), nil
|
||||
}
|
||||
|
||||
func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
|
||||
@@ -70,7 +70,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
})
|
||||
|
||||
if !containStreamUsage {
|
||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -30,7 +30,7 @@ type ParamOperation struct {
|
||||
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
|
||||
}
|
||||
|
||||
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
||||
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
|
||||
if len(paramOverride) == 0 {
|
||||
return jsonData, nil
|
||||
}
|
||||
@@ -38,7 +38,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) (
|
||||
// 尝试断言为操作格式
|
||||
if operations, ok := tryParseOperations(paramOverride); ok {
|
||||
// 使用新方法
|
||||
result, err := applyOperations(string(jsonData), operations)
|
||||
result, err := applyOperations(string(jsonData), operations, conditionContext)
|
||||
return []byte(result), err
|
||||
}
|
||||
|
||||
@@ -123,13 +123,13 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) {
|
||||
func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
|
||||
if len(conditions) == 0 {
|
||||
return true, nil // 没有条件,直接通过
|
||||
}
|
||||
results := make([]bool, len(conditions))
|
||||
for i, condition := range conditions {
|
||||
result, err := checkSingleCondition(jsonStr, condition)
|
||||
result, err := checkSingleCondition(jsonStr, contextJSON, condition)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -153,10 +153,13 @@ func checkConditions(jsonStr string, conditions []ConditionOperation, logic stri
|
||||
}
|
||||
}
|
||||
|
||||
func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
|
||||
func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
|
||||
// 处理负数索引
|
||||
path := processNegativeIndex(jsonStr, condition.Path)
|
||||
value := gjson.Get(jsonStr, path)
|
||||
if !value.Exists() && contextJSON != "" {
|
||||
value = gjson.Get(contextJSON, condition.Path)
|
||||
}
|
||||
if !value.Exists() {
|
||||
if condition.PassMissingKey {
|
||||
return true, nil
|
||||
@@ -165,7 +168,7 @@ func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, e
|
||||
}
|
||||
|
||||
// 利用gjson的类型解析
|
||||
targetBytes, err := json.Marshal(condition.Value)
|
||||
targetBytes, err := common.Marshal(condition.Value)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to marshal condition value: %v", err)
|
||||
}
|
||||
@@ -292,7 +295,7 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool,
|
||||
// applyOperationsLegacy 原参数覆盖方法
|
||||
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
||||
reqMap := make(map[string]interface{})
|
||||
err := json.Unmarshal(jsonData, &reqMap)
|
||||
err := common.Unmarshal(jsonData, &reqMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -301,14 +304,23 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}
|
||||
reqMap[key] = value
|
||||
}
|
||||
|
||||
return json.Marshal(reqMap)
|
||||
return common.Marshal(reqMap)
|
||||
}
|
||||
|
||||
func applyOperations(jsonStr string, operations []ParamOperation) (string, error) {
|
||||
func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
|
||||
var contextJSON string
|
||||
if conditionContext != nil && len(conditionContext) > 0 {
|
||||
ctxBytes, err := common.Marshal(conditionContext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal condition context: %v", err)
|
||||
}
|
||||
contextJSON = string(ctxBytes)
|
||||
}
|
||||
|
||||
result := jsonStr
|
||||
for _, op := range operations {
|
||||
// 检查条件是否满足
|
||||
ok, err := checkConditions(result, op.Conditions, op.Logic)
|
||||
ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -414,7 +426,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
||||
var currentMap, newMap map[string]interface{}
|
||||
|
||||
// 解析当前值
|
||||
if err := json.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
|
||||
if err := common.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// 解析新值
|
||||
@@ -422,8 +434,8 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
||||
case map[string]interface{}:
|
||||
newMap = v
|
||||
default:
|
||||
jsonBytes, _ := json.Marshal(v)
|
||||
if err := json.Unmarshal(jsonBytes, &newMap); err != nil {
|
||||
jsonBytes, _ := common.Marshal(v)
|
||||
if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
@@ -439,3 +451,31 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
||||
}
|
||||
return sjson.Set(jsonStr, path, result)
|
||||
}
|
||||
|
||||
// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
|
||||
// 目前内置以下字段:
|
||||
// - model:优先使用上游模型名(UpstreamModelName),若不存在则回落到原始模型名(OriginModelName)。
|
||||
// - upstream_model:始终为通道映射后的上游模型名。
|
||||
// - original_model:请求最初指定的模型名。
|
||||
func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
|
||||
if info == nil || info.ChannelMeta == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := make(map[string]interface{})
|
||||
if info.UpstreamModelName != "" {
|
||||
ctx["model"] = info.UpstreamModelName
|
||||
ctx["upstream_model"] = info.UpstreamModelName
|
||||
}
|
||||
if info.OriginModelName != "" {
|
||||
ctx["original_model"] = info.OriginModelName
|
||||
if _, exists := ctx["model"]; !exists {
|
||||
ctx["model"] = info.OriginModelName
|
||||
}
|
||||
}
|
||||
|
||||
if len(ctx) == 0 {
|
||||
return nil
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -49,6 +49,14 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData)))
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
@@ -156,7 +156,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
|
||||
apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth)
|
||||
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth)
|
||||
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
// return image.Config, format, clean base64 string, error
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
|
||||
// 去除base64数据的URL前缀(如果有)
|
||||
if idx := strings.Index(base64String, ","); idx != -1 {
|
||||
|
||||
@@ -62,6 +62,12 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
adminInfo["is_multi_key"] = true
|
||||
adminInfo["multi_key_index"] = common.GetContextKeyInt(ctx, constant.ContextKeyChannelMultiKeyIndex)
|
||||
}
|
||||
|
||||
isLocalCountTokens := common.GetContextKeyBool(ctx, constant.ContextKeyLocalCountTokens)
|
||||
if isLocalCountTokens {
|
||||
adminInfo["local_count_tokens"] = isLocalCountTokens
|
||||
}
|
||||
|
||||
other["admin_info"] = adminInfo
|
||||
appendRequestPath(ctx, relayInfo, other)
|
||||
return other
|
||||
|
||||
@@ -143,6 +143,12 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
if fileMeta.Detail == "low" && !isPatchBased {
|
||||
return baseTokens, nil
|
||||
}
|
||||
|
||||
// Whether to count image tokens at all
|
||||
if !constant.GetMediaToken {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
|
||||
if !constant.GetMediaTokenNotStream && !stream {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
@@ -150,10 +156,6 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
|
||||
fileMeta.Detail = "high"
|
||||
}
|
||||
// Whether to count image tokens at all
|
||||
if !constant.GetMediaToken {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
|
||||
// Decode image to get dimensions
|
||||
var config image.Config
|
||||
@@ -256,16 +258,15 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
}
|
||||
|
||||
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
||||
// 是否统计token
|
||||
if !constant.CountToken {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if meta == nil {
|
||||
return 0, errors.New("token count meta is nil")
|
||||
}
|
||||
|
||||
if !constant.GetMediaToken {
|
||||
return 0, nil
|
||||
}
|
||||
if !constant.GetMediaTokenNotStream && !info.IsStream {
|
||||
return 0, nil
|
||||
}
|
||||
if info.RelayFormat == types.RelayFormatOpenAIRealtime {
|
||||
return 0, nil
|
||||
}
|
||||
@@ -316,9 +317,19 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
shouldFetchFiles = false
|
||||
}
|
||||
|
||||
if shouldFetchFiles {
|
||||
for _, file := range meta.Files {
|
||||
if strings.HasPrefix(file.OriginData, "http") {
|
||||
// 是否本地计算媒体token数量
|
||||
if !constant.GetMediaToken {
|
||||
shouldFetchFiles = false
|
||||
}
|
||||
|
||||
// 是否在非流模式下本地计算媒体token数量
|
||||
if !constant.GetMediaTokenNotStream && !info.IsStream {
|
||||
shouldFetchFiles = false
|
||||
}
|
||||
|
||||
for _, file := range meta.Files {
|
||||
if strings.HasPrefix(file.OriginData, "http") {
|
||||
if shouldFetchFiles {
|
||||
mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting file base64 from url: %v", err)
|
||||
@@ -333,28 +344,28 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
file.FileType = types.FileTypeFile
|
||||
}
|
||||
file.MimeType = mineType
|
||||
} else if strings.HasPrefix(file.OriginData, "data:") {
|
||||
// get mime type from base64 header
|
||||
parts := strings.SplitN(file.OriginData, ",", 2)
|
||||
if len(parts) >= 1 {
|
||||
header := parts[0]
|
||||
// Extract mime type from "data:mime/type;base64" format
|
||||
if strings.Contains(header, ":") && strings.Contains(header, ";") {
|
||||
mimeStart := strings.Index(header, ":") + 1
|
||||
mimeEnd := strings.Index(header, ";")
|
||||
if mimeStart < mimeEnd {
|
||||
mineType := header[mimeStart:mimeEnd]
|
||||
if strings.HasPrefix(mineType, "image/") {
|
||||
file.FileType = types.FileTypeImage
|
||||
} else if strings.HasPrefix(mineType, "video/") {
|
||||
file.FileType = types.FileTypeVideo
|
||||
} else if strings.HasPrefix(mineType, "audio/") {
|
||||
file.FileType = types.FileTypeAudio
|
||||
} else {
|
||||
file.FileType = types.FileTypeFile
|
||||
}
|
||||
file.MimeType = mineType
|
||||
}
|
||||
} else if strings.HasPrefix(file.OriginData, "data:") {
|
||||
// get mime type from base64 header
|
||||
parts := strings.SplitN(file.OriginData, ",", 2)
|
||||
if len(parts) >= 1 {
|
||||
header := parts[0]
|
||||
// Extract mime type from "data:mime/type;base64" format
|
||||
if strings.Contains(header, ":") && strings.Contains(header, ";") {
|
||||
mimeStart := strings.Index(header, ":") + 1
|
||||
mimeEnd := strings.Index(header, ";")
|
||||
if mimeStart < mimeEnd {
|
||||
mineType := header[mimeStart:mimeEnd]
|
||||
if strings.HasPrefix(mineType, "image/") {
|
||||
file.FileType = types.FileTypeImage
|
||||
} else if strings.HasPrefix(mineType, "video/") {
|
||||
file.FileType = types.FileTypeVideo
|
||||
} else if strings.HasPrefix(mineType, "audio/") {
|
||||
file.FileType = types.FileTypeAudio
|
||||
} else {
|
||||
file.FileType = types.FileTypeFile
|
||||
}
|
||||
file.MimeType = mineType
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -365,7 +376,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
switch file.FileType {
|
||||
case types.FileTypeImage:
|
||||
if info.RelayFormat == types.RelayFormatGemini {
|
||||
tkm += 256
|
||||
tkm += 520 // gemini per input image tokens
|
||||
} else {
|
||||
token, err := getImageToken(file, model, info.IsStream)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
//func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
|
||||
@@ -16,7 +19,8 @@ import (
|
||||
// return 0, errors.New("unknown relay mode")
|
||||
//}
|
||||
|
||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
|
||||
func ResponseText2Usage(c *gin.Context, responseText string, modeName string, promptTokens int) *dto.Usage {
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
ctkm := CountTextToken(responseText, modeName)
|
||||
|
||||
@@ -17,8 +17,7 @@ type GeminiSettings struct {
|
||||
// 默认配置
|
||||
var defaultGeminiSettings = GeminiSettings{
|
||||
SafetySettings: map[string]string{
|
||||
"default": "OFF",
|
||||
"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
|
||||
"default": "OFF",
|
||||
},
|
||||
VersionSettings: map[string]string{
|
||||
"default": "v1beta",
|
||||
|
||||
@@ -598,6 +598,11 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
|
||||
return 2.5 / 0.3, false
|
||||
} else if strings.HasPrefix(name, "gemini-robotics-er-1.5") {
|
||||
return 2.5 / 0.3, false
|
||||
} else if strings.HasPrefix(name, "gemini-3-pro") {
|
||||
if strings.HasPrefix(name, "gemini-3-pro-image") {
|
||||
return 60, false
|
||||
}
|
||||
return 6, false
|
||||
}
|
||||
return 4, false
|
||||
}
|
||||
|
||||
21
setting/system_setting/discord.go
Normal file
21
setting/system_setting/discord.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package system_setting
|
||||
|
||||
import "github.com/QuantumNous/new-api/setting/config"
|
||||
|
||||
type DiscordSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ClientId string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var defaultDiscordSettings = DiscordSettings{}
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器
|
||||
config.GlobalConfig.Register("discord", &defaultDiscordSettings)
|
||||
}
|
||||
|
||||
func GetDiscordSettings() *DiscordSettings {
|
||||
return &defaultDiscordSettings
|
||||
}
|
||||
@@ -192,6 +192,14 @@ function App() {
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/discord'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>} key={location.pathname}>
|
||||
<OAuth2Callback type='discord'></OAuth2Callback>
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/oidc'
|
||||
element={
|
||||
|
||||
@@ -30,6 +30,7 @@ import {
|
||||
getSystemName,
|
||||
setUserData,
|
||||
onGitHubOAuthClicked,
|
||||
onDiscordOAuthClicked,
|
||||
onOIDCClicked,
|
||||
onLinuxDOOAuthClicked,
|
||||
prepareCredentialRequestOptions,
|
||||
@@ -53,6 +54,7 @@ import WeChatIcon from '../common/logo/WeChatIcon';
|
||||
import LinuxDoIcon from '../common/logo/LinuxDoIcon';
|
||||
import TwoFAVerification from './TwoFAVerification';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { SiDiscord }from 'react-icons/si';
|
||||
|
||||
const LoginForm = () => {
|
||||
let navigate = useNavigate();
|
||||
@@ -73,6 +75,7 @@ const LoginForm = () => {
|
||||
const [showEmailLogin, setShowEmailLogin] = useState(false);
|
||||
const [wechatLoading, setWechatLoading] = useState(false);
|
||||
const [githubLoading, setGithubLoading] = useState(false);
|
||||
const [discordLoading, setDiscordLoading] = useState(false);
|
||||
const [oidcLoading, setOidcLoading] = useState(false);
|
||||
const [linuxdoLoading, setLinuxdoLoading] = useState(false);
|
||||
const [emailLoginLoading, setEmailLoginLoading] = useState(false);
|
||||
@@ -298,6 +301,21 @@ const LoginForm = () => {
|
||||
}
|
||||
};
|
||||
|
||||
// 包装的Discord登录点击处理
|
||||
const handleDiscordClick = () => {
|
||||
if ((hasUserAgreement || hasPrivacyPolicy) && !agreedToTerms) {
|
||||
showInfo(t('请先阅读并同意用户协议和隐私政策'));
|
||||
return;
|
||||
}
|
||||
setDiscordLoading(true);
|
||||
try {
|
||||
onDiscordOAuthClicked(status.discord_client_id);
|
||||
} finally {
|
||||
// 由于重定向,这里不会执行到,但为了完整性添加
|
||||
setTimeout(() => setDiscordLoading(false), 3000);
|
||||
}
|
||||
};
|
||||
|
||||
// 包装的OIDC登录点击处理
|
||||
const handleOIDCClick = () => {
|
||||
if ((hasUserAgreement || hasPrivacyPolicy) && !agreedToTerms) {
|
||||
@@ -472,6 +490,19 @@ const LoginForm = () => {
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{status.discord_oauth && (
|
||||
<Button
|
||||
theme='outline'
|
||||
className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
|
||||
type='tertiary'
|
||||
icon={<SiDiscord style={{ color: '#5865F2', width: '20px', height: '20px' }} />}
|
||||
onClick={handleDiscordClick}
|
||||
loading={discordLoading}
|
||||
>
|
||||
<span className='ml-3'>{t('使用 Discord 继续')}</span>
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{status.oidc_enabled && (
|
||||
<Button
|
||||
theme='outline'
|
||||
@@ -714,6 +745,7 @@ const LoginForm = () => {
|
||||
</Form>
|
||||
|
||||
{(status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
@@ -849,6 +881,7 @@ const LoginForm = () => {
|
||||
{showEmailLogin ||
|
||||
!(
|
||||
status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
|
||||
@@ -28,6 +28,7 @@ import {
|
||||
updateAPI,
|
||||
getSystemName,
|
||||
setUserData,
|
||||
onDiscordOAuthClicked,
|
||||
} from '../../helpers';
|
||||
import Turnstile from 'react-turnstile';
|
||||
import { Button, Card, Checkbox, Divider, Form, Icon, Modal } from '@douyinfe/semi-ui';
|
||||
@@ -51,6 +52,7 @@ import WeChatIcon from '../common/logo/WeChatIcon';
|
||||
import TelegramLoginButton from 'react-telegram-login/src';
|
||||
import { UserContext } from '../../context/User';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { SiDiscord } from 'react-icons/si';
|
||||
|
||||
const RegisterForm = () => {
|
||||
let navigate = useNavigate();
|
||||
@@ -72,6 +74,7 @@ const RegisterForm = () => {
|
||||
const [showEmailRegister, setShowEmailRegister] = useState(false);
|
||||
const [wechatLoading, setWechatLoading] = useState(false);
|
||||
const [githubLoading, setGithubLoading] = useState(false);
|
||||
const [discordLoading, setDiscordLoading] = useState(false);
|
||||
const [oidcLoading, setOidcLoading] = useState(false);
|
||||
const [linuxdoLoading, setLinuxdoLoading] = useState(false);
|
||||
const [emailRegisterLoading, setEmailRegisterLoading] = useState(false);
|
||||
@@ -264,6 +267,15 @@ const RegisterForm = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const handleDiscordClick = () => {
|
||||
setDiscordLoading(true);
|
||||
try {
|
||||
onDiscordOAuthClicked(status.discord_client_id);
|
||||
} finally {
|
||||
setTimeout(() => setDiscordLoading(false), 3000);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOIDCClick = () => {
|
||||
setOidcLoading(true);
|
||||
try {
|
||||
@@ -377,6 +389,19 @@ const RegisterForm = () => {
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{status.discord_oauth && (
|
||||
<Button
|
||||
theme='outline'
|
||||
className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
|
||||
type='tertiary'
|
||||
icon={<SiDiscord style={{ color: '#5865F2', width: '20px', height: '20px' }} />}
|
||||
onClick={handleDiscordClick}
|
||||
loading={discordLoading}
|
||||
>
|
||||
<span className='ml-3'>{t('使用 Discord 继续')}</span>
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{status.oidc_enabled && (
|
||||
<Button
|
||||
theme='outline'
|
||||
@@ -591,6 +616,7 @@ const RegisterForm = () => {
|
||||
</Form>
|
||||
|
||||
{(status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
@@ -686,6 +712,7 @@ const RegisterForm = () => {
|
||||
{showEmailRegister ||
|
||||
!(
|
||||
status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
|
||||
@@ -20,7 +20,7 @@ For commercial licensing, please contact support@quantumnous.com
|
||||
import React from 'react';
|
||||
import { Button, Dropdown } from '@douyinfe/semi-ui';
|
||||
import { Languages } from 'lucide-react';
|
||||
import { CN, GB, FR, RU, JP } from 'country-flag-icons/react/3x2';
|
||||
import { CN, GB, FR, RU, JP, VN } from 'country-flag-icons/react/3x2';
|
||||
|
||||
const LanguageSelector = ({ currentLang, onLanguageChange, t }) => {
|
||||
return (
|
||||
@@ -65,6 +65,13 @@ const LanguageSelector = ({ currentLang, onLanguageChange, t }) => {
|
||||
<RU title='Русский' className='!w-5 !h-auto' />
|
||||
<span>Русский</span>
|
||||
</Dropdown.Item>
|
||||
<Dropdown.Item
|
||||
onClick={() => onLanguageChange('vi')}
|
||||
className={`!flex !items-center !gap-2 !px-3 !py-1.5 !text-sm !text-semi-color-text-0 dark:!text-gray-200 ${currentLang === 'vi' ? '!bg-semi-color-primary-light-default dark:!bg-blue-600 !font-semibold' : 'hover:!bg-semi-color-fill-1 dark:hover:!bg-gray-600'}`}
|
||||
>
|
||||
<VN title='Tiếng Việt' className='!w-5 !h-auto' />
|
||||
<span>Tiếng Việt</span>
|
||||
</Dropdown.Item>
|
||||
</Dropdown.Menu>
|
||||
}
|
||||
>
|
||||
|
||||
@@ -52,6 +52,9 @@ const SystemSetting = () => {
|
||||
GitHubOAuthEnabled: '',
|
||||
GitHubClientId: '',
|
||||
GitHubClientSecret: '',
|
||||
'discord.enabled': '',
|
||||
'discord.client_id': '',
|
||||
'discord.client_secret': '',
|
||||
'oidc.enabled': '',
|
||||
'oidc.client_id': '',
|
||||
'oidc.client_secret': '',
|
||||
@@ -179,6 +182,7 @@ const SystemSetting = () => {
|
||||
case 'EmailAliasRestrictionEnabled':
|
||||
case 'SMTPSSLEnabled':
|
||||
case 'LinuxDOOAuthEnabled':
|
||||
case 'discord.enabled':
|
||||
case 'oidc.enabled':
|
||||
case 'passkey.enabled':
|
||||
case 'passkey.allow_insecure_origin':
|
||||
@@ -473,6 +477,27 @@ const SystemSetting = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const submitDiscordOAuth = async () => {
|
||||
const options = [];
|
||||
|
||||
if (originInputs['discord.client_id'] !== inputs['discord.client_id']) {
|
||||
options.push({ key: 'discord.client_id', value: inputs['discord.client_id'] });
|
||||
}
|
||||
if (
|
||||
originInputs['discord.client_secret'] !== inputs['discord.client_secret'] &&
|
||||
inputs['discord.client_secret'] !== ''
|
||||
) {
|
||||
options.push({
|
||||
key: 'discord.client_secret',
|
||||
value: inputs['discord.client_secret'],
|
||||
});
|
||||
}
|
||||
|
||||
if (options.length > 0) {
|
||||
await updateOptions(options);
|
||||
}
|
||||
};
|
||||
|
||||
const submitOIDCSettings = async () => {
|
||||
if (inputs['oidc.well_known'] && inputs['oidc.well_known'] !== '') {
|
||||
if (
|
||||
@@ -1014,6 +1039,15 @@ const SystemSetting = () => {
|
||||
>
|
||||
{t('允许通过 GitHub 账户登录 & 注册')}
|
||||
</Form.Checkbox>
|
||||
<Form.Checkbox
|
||||
field='discord.enabled'
|
||||
noLabel
|
||||
onChange={(e) =>
|
||||
handleCheckboxChange('discord.enabled', e)
|
||||
}
|
||||
>
|
||||
{t('允许通过 Discord 账户登录 & 注册')}
|
||||
</Form.Checkbox>
|
||||
<Form.Checkbox
|
||||
field='LinuxDOOAuthEnabled'
|
||||
noLabel
|
||||
@@ -1410,6 +1444,37 @@ const SystemSetting = () => {
|
||||
</Button>
|
||||
</Form.Section>
|
||||
</Card>
|
||||
<Card>
|
||||
<Form.Section text={t('配置 Discord OAuth')}>
|
||||
<Text>{t('用以支持通过 Discord 进行登录注册')}</Text>
|
||||
<Banner
|
||||
type='info'
|
||||
description={`${t('Homepage URL 填')} ${inputs.ServerAddress ? inputs.ServerAddress : t('网站地址')},${t('Authorization callback URL 填')} ${inputs.ServerAddress ? inputs.ServerAddress : t('网站地址')}/oauth/discord`}
|
||||
style={{ marginBottom: 20, marginTop: 16 }}
|
||||
/>
|
||||
<Row
|
||||
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
|
||||
>
|
||||
<Col xs={24} sm={24} md={12} lg={12} xl={12}>
|
||||
<Form.Input
|
||||
field="['discord.client_id']"
|
||||
label={t('Discord Client ID')}
|
||||
/>
|
||||
</Col>
|
||||
<Col xs={24} sm={24} md={12} lg={12} xl={12}>
|
||||
<Form.Input
|
||||
field="['discord.client_secret']"
|
||||
label={t('Discord Client Secret')}
|
||||
type='password'
|
||||
placeholder={t('敏感信息不会发送到前端显示')}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Button onClick={submitDiscordOAuth}>
|
||||
{t('保存 Discord OAuth 设置')}
|
||||
</Button>
|
||||
</Form.Section>
|
||||
</Card>
|
||||
<Card>
|
||||
<Form.Section text={t('配置 Linux DO OAuth')}>
|
||||
<Text>
|
||||
|
||||
@@ -38,13 +38,14 @@ import {
|
||||
IconLock,
|
||||
IconDelete,
|
||||
} from '@douyinfe/semi-icons';
|
||||
import { SiTelegram, SiWechat, SiLinux } from 'react-icons/si';
|
||||
import { SiTelegram, SiWechat, SiLinux, SiDiscord } from 'react-icons/si';
|
||||
import { UserPlus, ShieldCheck } from 'lucide-react';
|
||||
import TelegramLoginButton from 'react-telegram-login';
|
||||
import {
|
||||
onGitHubOAuthClicked,
|
||||
onOIDCClicked,
|
||||
onLinuxDOOAuthClicked,
|
||||
onDiscordOAuthClicked,
|
||||
} from '../../../../helpers';
|
||||
import TwoFASetting from '../components/TwoFASetting';
|
||||
|
||||
@@ -247,6 +248,47 @@ const AccountManagement = ({
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
{/* Discord绑定 */}
|
||||
<Card className='!rounded-xl'>
|
||||
<div className='flex items-center justify-between gap-3'>
|
||||
<div className='flex items-center flex-1 min-w-0'>
|
||||
<div className='w-10 h-10 rounded-full bg-slate-100 dark:bg-slate-700 flex items-center justify-center mr-3 flex-shrink-0'>
|
||||
<SiDiscord
|
||||
size={20}
|
||||
className='text-slate-600 dark:text-slate-300'
|
||||
/>
|
||||
</div>
|
||||
<div className='flex-1 min-w-0'>
|
||||
<div className='font-medium text-gray-900'>
|
||||
{t('Discord')}
|
||||
</div>
|
||||
<div className='text-sm text-gray-500 truncate'>
|
||||
{renderAccountInfo(
|
||||
userState.user?.discord_id,
|
||||
t('Discord ID'),
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className='flex-shrink-0'>
|
||||
<Button
|
||||
type='primary'
|
||||
theme='outline'
|
||||
size='small'
|
||||
onClick={() =>
|
||||
onDiscordOAuthClicked(status.discord_client_id)
|
||||
}
|
||||
disabled={
|
||||
isBound(userState.user?.discord_id) ||
|
||||
!status.discord_oauth
|
||||
}
|
||||
>
|
||||
{status.discord_oauth ? t('绑定') : t('未启用')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
{/* OIDC绑定 */}
|
||||
<Card className='!rounded-xl'>
|
||||
<div className='flex items-center justify-between gap-3'>
|
||||
|
||||
@@ -190,6 +190,30 @@ const EditChannelModal = (props) => {
|
||||
const [keyMode, setKeyMode] = useState('append'); // 密钥模式:replace(覆盖)或 append(追加)
|
||||
const [isEnterpriseAccount, setIsEnterpriseAccount] = useState(false); // 是否为企业账户
|
||||
const [doubaoApiEditUnlocked, setDoubaoApiEditUnlocked] = useState(false); // 豆包渠道自定义 API 地址隐藏入口
|
||||
const redirectModelList = useMemo(() => {
|
||||
const mapping = inputs.model_mapping;
|
||||
if (typeof mapping !== 'string') return [];
|
||||
const trimmed = mapping.trim();
|
||||
if (!trimmed) return [];
|
||||
try {
|
||||
const parsed = JSON.parse(trimmed);
|
||||
if (
|
||||
!parsed ||
|
||||
typeof parsed !== 'object' ||
|
||||
Array.isArray(parsed)
|
||||
) {
|
||||
return [];
|
||||
}
|
||||
const values = Object.values(parsed)
|
||||
.map((value) =>
|
||||
typeof value === 'string' ? value.trim() : undefined,
|
||||
)
|
||||
.filter((value) => value);
|
||||
return Array.from(new Set(values));
|
||||
} catch (error) {
|
||||
return [];
|
||||
}
|
||||
}, [inputs.model_mapping]);
|
||||
|
||||
// 密钥显示状态
|
||||
const [keyDisplayState, setKeyDisplayState] = useState({
|
||||
@@ -220,6 +244,8 @@ const EditChannelModal = (props) => {
|
||||
];
|
||||
const formContainerRef = useRef(null);
|
||||
const doubaoApiClickCountRef = useRef(0);
|
||||
const initialModelsRef = useRef([]);
|
||||
const initialModelMappingRef = useRef('');
|
||||
|
||||
// 2FA状态更新辅助函数
|
||||
const updateTwoFAState = (updates) => {
|
||||
@@ -595,6 +621,10 @@ const EditChannelModal = (props) => {
|
||||
system_prompt: data.system_prompt,
|
||||
system_prompt_override: data.system_prompt_override || false,
|
||||
});
|
||||
initialModelsRef.current = (data.models || [])
|
||||
.map((model) => (model || '').trim())
|
||||
.filter(Boolean);
|
||||
initialModelMappingRef.current = data.model_mapping || '';
|
||||
// console.log(data);
|
||||
} else {
|
||||
showError(message);
|
||||
@@ -830,6 +860,13 @@ const EditChannelModal = (props) => {
|
||||
}
|
||||
}, [props.visible, channelId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isEdit) {
|
||||
initialModelsRef.current = [];
|
||||
initialModelMappingRef.current = '';
|
||||
}
|
||||
}, [isEdit, props.visible]);
|
||||
|
||||
// 统一的模态框重置函数
|
||||
const resetModalState = () => {
|
||||
formApiRef.current?.reset();
|
||||
@@ -903,6 +940,80 @@ const EditChannelModal = (props) => {
|
||||
})();
|
||||
};
|
||||
|
||||
const confirmMissingModelMappings = (missingModels) =>
|
||||
new Promise((resolve) => {
|
||||
const modal = Modal.confirm({
|
||||
title: t('模型未加入列表,可能无法调用'),
|
||||
content: (
|
||||
<div className='text-sm leading-6'>
|
||||
<div>
|
||||
{t(
|
||||
'模型重定向里的下列模型尚未添加到“模型”列表,调用时会因为缺少可用模型而失败:',
|
||||
)}
|
||||
</div>
|
||||
<div className='font-mono text-xs break-all text-red-600 mt-1'>
|
||||
{missingModels.join(', ')}
|
||||
</div>
|
||||
<div className='mt-2'>
|
||||
{t(
|
||||
'你可以在“自定义模型名称”处手动添加它们,然后点击填入后再提交,或者直接使用下方操作自动处理。',
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
centered: true,
|
||||
footer: (
|
||||
<Space align='center' className='w-full justify-end'>
|
||||
<Button
|
||||
type='tertiary'
|
||||
onClick={() => {
|
||||
modal.destroy();
|
||||
resolve('cancel');
|
||||
}}
|
||||
>
|
||||
{t('返回修改')}
|
||||
</Button>
|
||||
<Button
|
||||
type='primary'
|
||||
theme='light'
|
||||
onClick={() => {
|
||||
modal.destroy();
|
||||
resolve('submit');
|
||||
}}
|
||||
>
|
||||
{t('直接提交')}
|
||||
</Button>
|
||||
<Button
|
||||
type='primary'
|
||||
theme='solid'
|
||||
onClick={() => {
|
||||
modal.destroy();
|
||||
resolve('add');
|
||||
}}
|
||||
>
|
||||
{t('添加后提交')}
|
||||
</Button>
|
||||
</Space>
|
||||
),
|
||||
});
|
||||
});
|
||||
|
||||
const hasModelConfigChanged = (normalizedModels, modelMappingStr) => {
|
||||
if (!isEdit) return true;
|
||||
const initialModels = initialModelsRef.current;
|
||||
if (normalizedModels.length !== initialModels.length) {
|
||||
return true;
|
||||
}
|
||||
for (let i = 0; i < normalizedModels.length; i++) {
|
||||
if (normalizedModels[i] !== initialModels[i]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
const normalizedMapping = (modelMappingStr || '').trim();
|
||||
const initialMapping = (initialModelMappingRef.current || '').trim();
|
||||
return normalizedMapping !== initialMapping;
|
||||
};
|
||||
|
||||
const submit = async () => {
|
||||
const formValues = formApiRef.current ? formApiRef.current.getValues() : {};
|
||||
let localInputs = { ...formValues };
|
||||
@@ -986,14 +1097,55 @@ const EditChannelModal = (props) => {
|
||||
showInfo(t('请输入API地址!'));
|
||||
return;
|
||||
}
|
||||
if (
|
||||
localInputs.model_mapping &&
|
||||
localInputs.model_mapping !== '' &&
|
||||
!verifyJSON(localInputs.model_mapping)
|
||||
) {
|
||||
showInfo(t('模型映射必须是合法的 JSON 格式!'));
|
||||
return;
|
||||
const hasModelMapping =
|
||||
typeof localInputs.model_mapping === 'string' &&
|
||||
localInputs.model_mapping.trim() !== '';
|
||||
let parsedModelMapping = null;
|
||||
if (hasModelMapping) {
|
||||
if (!verifyJSON(localInputs.model_mapping)) {
|
||||
showInfo(t('模型映射必须是合法的 JSON 格式!'));
|
||||
return;
|
||||
}
|
||||
try {
|
||||
parsedModelMapping = JSON.parse(localInputs.model_mapping);
|
||||
} catch (error) {
|
||||
showInfo(t('模型映射必须是合法的 JSON 格式!'));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const normalizedModels = (localInputs.models || [])
|
||||
.map((model) => (model || '').trim())
|
||||
.filter(Boolean);
|
||||
localInputs.models = normalizedModels;
|
||||
|
||||
if (
|
||||
parsedModelMapping &&
|
||||
typeof parsedModelMapping === 'object' &&
|
||||
!Array.isArray(parsedModelMapping)
|
||||
) {
|
||||
const modelSet = new Set(normalizedModels);
|
||||
const missingModels = Object.keys(parsedModelMapping)
|
||||
.map((key) => (key || '').trim())
|
||||
.filter((key) => key && !modelSet.has(key));
|
||||
const shouldPromptMissing =
|
||||
missingModels.length > 0 &&
|
||||
hasModelConfigChanged(normalizedModels, localInputs.model_mapping);
|
||||
if (shouldPromptMissing) {
|
||||
const confirmAction = await confirmMissingModelMappings(missingModels);
|
||||
if (confirmAction === 'cancel') {
|
||||
return;
|
||||
}
|
||||
if (confirmAction === 'add') {
|
||||
const updatedModels = Array.from(
|
||||
new Set([...normalizedModels, ...missingModels]),
|
||||
);
|
||||
localInputs.models = updatedModels;
|
||||
handleInputChange('models', updatedModels);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
|
||||
localInputs.base_url = localInputs.base_url.slice(
|
||||
0,
|
||||
@@ -2916,6 +3068,7 @@ const EditChannelModal = (props) => {
|
||||
visible={modelModalVisible}
|
||||
models={fetchedModels}
|
||||
selected={inputs.models}
|
||||
redirectModels={redirectModelList}
|
||||
onConfirm={(selectedModels) => {
|
||||
handleInputChange('models', selectedModels);
|
||||
showSuccess(t('模型列表已更新'));
|
||||
|
||||
@@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import React, { useState, useEffect, useMemo } from 'react';
|
||||
import { useIsMobile } from '../../../../hooks/common/useIsMobile';
|
||||
import {
|
||||
Modal,
|
||||
@@ -28,12 +28,13 @@ import {
|
||||
Empty,
|
||||
Tabs,
|
||||
Collapse,
|
||||
Tooltip,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IllustrationNoResult,
|
||||
IllustrationNoResultDark,
|
||||
} from '@douyinfe/semi-illustrations';
|
||||
import { IconSearch } from '@douyinfe/semi-icons';
|
||||
import { IconSearch, IconInfoCircle } from '@douyinfe/semi-icons';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { getModelCategories } from '../../../../helpers/render';
|
||||
|
||||
@@ -41,6 +42,7 @@ const ModelSelectModal = ({
|
||||
visible,
|
||||
models = [],
|
||||
selected = [],
|
||||
redirectModels = [],
|
||||
onConfirm,
|
||||
onCancel,
|
||||
}) => {
|
||||
@@ -50,15 +52,54 @@ const ModelSelectModal = ({
|
||||
const [activeTab, setActiveTab] = useState('new');
|
||||
|
||||
const isMobile = useIsMobile();
|
||||
const normalizeModelName = (model) =>
|
||||
typeof model === 'string' ? model.trim() : '';
|
||||
const normalizedRedirectModels = useMemo(
|
||||
() =>
|
||||
Array.from(
|
||||
new Set(
|
||||
(redirectModels || [])
|
||||
.map((model) => normalizeModelName(model))
|
||||
.filter(Boolean),
|
||||
),
|
||||
),
|
||||
[redirectModels],
|
||||
);
|
||||
const normalizedSelectedSet = useMemo(() => {
|
||||
const set = new Set();
|
||||
(selected || []).forEach((model) => {
|
||||
const normalized = normalizeModelName(model);
|
||||
if (normalized) {
|
||||
set.add(normalized);
|
||||
}
|
||||
});
|
||||
return set;
|
||||
}, [selected]);
|
||||
const classificationSet = useMemo(() => {
|
||||
const set = new Set(normalizedSelectedSet);
|
||||
normalizedRedirectModels.forEach((model) => set.add(model));
|
||||
return set;
|
||||
}, [normalizedSelectedSet, normalizedRedirectModels]);
|
||||
const redirectOnlySet = useMemo(() => {
|
||||
const set = new Set();
|
||||
normalizedRedirectModels.forEach((model) => {
|
||||
if (!normalizedSelectedSet.has(model)) {
|
||||
set.add(model);
|
||||
}
|
||||
});
|
||||
return set;
|
||||
}, [normalizedRedirectModels, normalizedSelectedSet]);
|
||||
|
||||
const filteredModels = models.filter((m) =>
|
||||
m.toLowerCase().includes(keyword.toLowerCase()),
|
||||
String(m || '').toLowerCase().includes(keyword.toLowerCase()),
|
||||
);
|
||||
|
||||
// 分类模型:新获取的模型和已有模型
|
||||
const newModels = filteredModels.filter((model) => !selected.includes(model));
|
||||
const isExistingModel = (model) =>
|
||||
classificationSet.has(normalizeModelName(model));
|
||||
const newModels = filteredModels.filter((model) => !isExistingModel(model));
|
||||
const existingModels = filteredModels.filter((model) =>
|
||||
selected.includes(model),
|
||||
isExistingModel(model),
|
||||
);
|
||||
|
||||
// 同步外部选中值
|
||||
@@ -228,7 +269,20 @@ const ModelSelectModal = ({
|
||||
<div className='grid grid-cols-2 gap-x-4'>
|
||||
{categoryData.models.map((model) => (
|
||||
<Checkbox key={model} value={model} className='my-1'>
|
||||
{model}
|
||||
<span className='flex items-center gap-2'>
|
||||
<span>{model}</span>
|
||||
{redirectOnlySet.has(normalizeModelName(model)) && (
|
||||
<Tooltip
|
||||
position='top'
|
||||
content={t('来自模型重定向,尚未加入模型列表')}
|
||||
>
|
||||
<IconInfoCircle
|
||||
size='small'
|
||||
className='text-amber-500 cursor-help'
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
</span>
|
||||
</Checkbox>
|
||||
))}
|
||||
</div>
|
||||
|
||||
@@ -72,6 +72,7 @@ const EditUserModal = (props) => {
|
||||
password: '',
|
||||
github_id: '',
|
||||
oidc_id: '',
|
||||
discord_id: '',
|
||||
wechat_id: '',
|
||||
telegram_id: '',
|
||||
email: '',
|
||||
@@ -332,6 +333,7 @@ const EditUserModal = (props) => {
|
||||
<Row gutter={12}>
|
||||
{[
|
||||
'github_id',
|
||||
'discord_id',
|
||||
'oidc_id',
|
||||
'wechat_id',
|
||||
'email',
|
||||
|
||||
@@ -231,6 +231,17 @@ export async function getOAuthState() {
|
||||
}
|
||||
}
|
||||
|
||||
export async function onDiscordOAuthClicked(client_id) {
|
||||
const state = await getOAuthState();
|
||||
if (!state) return;
|
||||
const redirect_uri = `${window.location.origin}/oauth/discord`;
|
||||
const response_type = 'code';
|
||||
const scope = 'identify+openid';
|
||||
window.open(
|
||||
`https://discord.com/oauth2/authorize?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`,
|
||||
);
|
||||
}
|
||||
|
||||
export async function onOIDCClicked(auth_url, client_id, openInNewTab = false) {
|
||||
const state = await getOAuthState();
|
||||
if (!state) return;
|
||||
|
||||
@@ -482,6 +482,18 @@ export const useLogsData = () => {
|
||||
value: other.request_path,
|
||||
});
|
||||
}
|
||||
if (isAdminUser) {
|
||||
let localCountMode = '';
|
||||
if (other?.admin_info?.local_count_tokens) {
|
||||
localCountMode = t('本地计费');
|
||||
} else {
|
||||
localCountMode = t('上游返回');
|
||||
}
|
||||
expandDataLocal.push({
|
||||
key: t('计费模式'),
|
||||
value: localCountMode,
|
||||
});
|
||||
}
|
||||
expandDatesLocal[logs[i].key] = expandDataLocal;
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ import frTranslation from './locales/fr.json';
|
||||
import zhTranslation from './locales/zh.json';
|
||||
import ruTranslation from './locales/ru.json';
|
||||
import jaTranslation from './locales/ja.json';
|
||||
import viTranslation from './locales/vi.json';
|
||||
|
||||
i18n
|
||||
.use(LanguageDetector)
|
||||
@@ -38,6 +39,7 @@ i18n
|
||||
fr: frTranslation,
|
||||
ru: ruTranslation,
|
||||
ja: jaTranslation,
|
||||
vi: viTranslation,
|
||||
},
|
||||
fallbackLng: 'zh',
|
||||
interpolation: {
|
||||
|
||||
2700
web/src/i18n/locales/vi.json
Normal file
2700
web/src/i18n/locales/vi.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -257,6 +257,7 @@
|
||||
"余额充值管理": "余额充值管理",
|
||||
"你似乎并没有修改什么": "你似乎并没有修改什么",
|
||||
"使用 GitHub 继续": "使用 GitHub 继续",
|
||||
"使用 Discord 继续": "使用 Discord 继续",
|
||||
"使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}": "使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}",
|
||||
"使用 LinuxDO 继续": "使用 LinuxDO 继续",
|
||||
"使用 OIDC 继续": "使用 OIDC 继续",
|
||||
|
||||
Reference in New Issue
Block a user