mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-31 06:07:34 +00:00
Compare commits
56 Commits
v0.9.21
...
coderabbit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8aedbb29c3 | ||
|
|
3f19f18dc9 | ||
|
|
a465597e78 | ||
|
|
dbfcb441f7 | ||
|
|
3fb2ba318d | ||
|
|
8f039b3a53 | ||
|
|
c939686509 | ||
|
|
07aff1fe02 | ||
|
|
5f27edcd19 | ||
|
|
f47d473e63 | ||
|
|
7a2bd38700 | ||
|
|
f8c40ecca6 | ||
|
|
2bc991685f | ||
|
|
87811a0493 | ||
|
|
0885597427 | ||
|
|
0952973887 | ||
|
|
6b30f042fa | ||
|
|
efb8f1f5b8 | ||
|
|
de3cf9893d | ||
|
|
fe02e9a066 | ||
|
|
84745d5ca4 | ||
|
|
cdb1c06ad2 | ||
|
|
ef0647285c | ||
|
|
33b1fad5f8 | ||
|
|
b899122dfe | ||
|
|
50c04a62f9 | ||
|
|
554b68484c | ||
|
|
6a1c046714 | ||
|
|
0b37bdddc6 | ||
|
|
563a426c00 | ||
|
|
f6a5d9ef7e | ||
|
|
a7d2450704 | ||
|
|
75fced3d9c | ||
|
|
5a1bbd1059 | ||
|
|
c133678cb1 | ||
|
|
1fc3c4b09d | ||
|
|
77c4c3e804 | ||
|
|
bc1f747418 | ||
|
|
62edac7c7f | ||
|
|
ff839df279 | ||
|
|
8b8511b19e | ||
|
|
7598753f4e | ||
|
|
68777bf05f | ||
|
|
b6217b22b0 | ||
|
|
196fa135fd | ||
|
|
ff3225ab44 | ||
|
|
ab36de3725 | ||
|
|
2b4617dc1b | ||
|
|
e169818404 | ||
|
|
c1a696e6f0 | ||
|
|
e07347ac53 | ||
|
|
fd38abd562 | ||
|
|
293c0277a8 | ||
|
|
344a799fcf | ||
|
|
d9b5748f80 | ||
|
|
fb3b27a626 |
@@ -63,10 +63,13 @@
|
||||
# 是否统计图片token
|
||||
# GET_MEDIA_TOKEN=true
|
||||
# 是否在非流(stream=false)情况下统计图片token
|
||||
# GET_MEDIA_TOKEN_NOT_STREAM=true
|
||||
# GET_MEDIA_TOKEN_NOT_STREAM=false
|
||||
# 设置 Dify 渠道是否输出工作流和节点信息到客户端
|
||||
# DIFY_DEBUG=true
|
||||
|
||||
# LinuxDo相关配置
|
||||
LINUX_DO_TOKEN_ENDPOINT=https://connect.linux.do/oauth2/token
|
||||
LINUX_DO_USER_ENDPOINT=https://connect.linux.do/api/user
|
||||
|
||||
# 节点类型
|
||||
# 如果是主节点则为master
|
||||
|
||||
24
.github/workflows/release.yml
vendored
24
.github/workflows/release.yml
vendored
@@ -22,6 +22,10 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
@@ -31,7 +35,7 @@ jobs:
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
@@ -40,13 +44,11 @@ jobs:
|
||||
- name: Build Backend (amd64)
|
||||
run: |
|
||||
go mod download
|
||||
VERSION=$(git describe --tags)
|
||||
go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION' -extldflags '-static'" -o new-api-$VERSION
|
||||
- name: Build Backend (arm64)
|
||||
run: |
|
||||
sudo apt-get update
|
||||
DEBIAN_FRONTEND=noninteractive sudo apt-get install -y gcc-aarch64-linux-gnu
|
||||
VERSION=$(git describe --tags)
|
||||
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION' -extldflags '-static'" -o new-api-arm64-$VERSION
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
@@ -65,6 +67,10 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
@@ -75,7 +81,7 @@ jobs:
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
@@ -84,7 +90,6 @@ jobs:
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
VERSION=$(git describe --tags)
|
||||
go build -ldflags "-X 'new-api/common.Version=$VERSION'" -o new-api-macos-$VERSION
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
@@ -105,6 +110,10 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Determine Version
|
||||
run: |
|
||||
VERSION=$(git describe --tags)
|
||||
echo "VERSION=$VERSION" >> $GITHUB_ENV
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
@@ -114,7 +123,7 @@ jobs:
|
||||
run: |
|
||||
cd web
|
||||
bun install
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$(git describe --tags) bun run build
|
||||
DISABLE_ESLINT_PLUGIN='true' VITE_REACT_APP_VERSION=$VERSION bun run build
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
@@ -123,7 +132,6 @@ jobs:
|
||||
- name: Build Backend
|
||||
run: |
|
||||
go mod download
|
||||
VERSION=$(git describe --tags)
|
||||
go build -ldflags "-s -w -X 'new-api/common.Version=$VERSION'" -o new-api-$VERSION.exe
|
||||
- name: Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
@@ -132,5 +140,3 @@ jobs:
|
||||
files: new-api-*.exe
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -16,6 +16,8 @@ new-api
|
||||
tiktoken_cache
|
||||
.eslintcache
|
||||
.gocache
|
||||
.cache
|
||||
web/bun.lock
|
||||
|
||||
electron/node_modules
|
||||
electron/dist
|
||||
|
||||
@@ -193,6 +193,7 @@ docker run --name new-api -d --restart always \
|
||||
|
||||
### 🔐 Authorization and Security
|
||||
|
||||
- 😈 Discord authorization login
|
||||
- 🤖 LinuxDO authorization login
|
||||
- 📱 Telegram authorization login
|
||||
- 🔑 OIDC unified authentication
|
||||
|
||||
@@ -193,6 +193,7 @@ docker run --name new-api -d --restart always \
|
||||
|
||||
### 🔐 授权与安全
|
||||
|
||||
- 😈 Discord 授权登录
|
||||
- 🤖 LinuxDO 授权登录
|
||||
- 📱 Telegram 授权登录
|
||||
- 🔑 OIDC 统一认证
|
||||
|
||||
@@ -2,7 +2,9 @@ package common
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -128,13 +130,13 @@ func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
|
||||
}
|
||||
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
boundary := ""
|
||||
if idx := strings.Index(contentType, "boundary="); idx != -1 {
|
||||
boundary = contentType[idx+9:]
|
||||
boundary, err := parseBoundary(contentType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reader := multipart.NewReader(bytes.NewReader(requestBody), boundary)
|
||||
form, err := reader.ReadForm(32 << 20) // 32 MB max memory
|
||||
form, err := reader.ReadForm(multipartMemoryLimit())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -177,17 +179,16 @@ func parseFormData(data []byte, v any) error {
|
||||
|
||||
func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
boundary := ""
|
||||
if idx := strings.Index(contentType, "boundary="); idx != -1 {
|
||||
boundary = contentType[idx+9:]
|
||||
}
|
||||
|
||||
if boundary == "" {
|
||||
return Unmarshal(data, v) // Fallback to JSON
|
||||
boundary, err := parseBoundary(contentType)
|
||||
if err != nil {
|
||||
if errors.Is(err, errBoundaryNotFound) {
|
||||
return Unmarshal(data, v) // Fallback to JSON
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
reader := multipart.NewReader(bytes.NewReader(data), boundary)
|
||||
form, err := reader.ReadForm(32 << 20) // 32 MB max memory
|
||||
form, err := reader.ReadForm(multipartMemoryLimit())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -203,3 +204,31 @@ func parseMultipartFormData(c *gin.Context, data []byte, v any) error {
|
||||
|
||||
return processFormMap(formMap, v)
|
||||
}
|
||||
|
||||
var errBoundaryNotFound = errors.New("multipart boundary not found")
|
||||
|
||||
// parseBoundary extracts the multipart boundary from the Content-Type header using mime.ParseMediaType
|
||||
func parseBoundary(contentType string) (string, error) {
|
||||
if contentType == "" {
|
||||
return "", errBoundaryNotFound
|
||||
}
|
||||
// Boundary-UUID / boundary-------xxxxxx
|
||||
_, params, err := mime.ParseMediaType(contentType)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
boundary, ok := params["boundary"]
|
||||
if !ok || boundary == "" {
|
||||
return "", errBoundaryNotFound
|
||||
}
|
||||
return boundary, nil
|
||||
}
|
||||
|
||||
// multipartMemoryLimit returns the configured multipart memory limit in bytes
|
||||
func multipartMemoryLimit() int64 {
|
||||
limitMB := constant.MaxFileDownloadMB
|
||||
if limitMB <= 0 {
|
||||
limitMB = 32
|
||||
}
|
||||
return int64(limitMB) << 20
|
||||
}
|
||||
|
||||
@@ -30,6 +30,11 @@ func printHelp() {
|
||||
func InitEnv() {
|
||||
flag.Parse()
|
||||
|
||||
envVersion := os.Getenv("VERSION")
|
||||
if envVersion != "" {
|
||||
Version = envVersion
|
||||
}
|
||||
|
||||
if *PrintVersion {
|
||||
fmt.Println(Version)
|
||||
os.Exit(0)
|
||||
@@ -111,8 +116,9 @@ func initConstantEnv() {
|
||||
constant.MaxFileDownloadMB = GetEnvOrDefault("MAX_FILE_DOWNLOAD_MB", 20)
|
||||
// ForceStreamOption 覆盖请求参数,强制返回usage信息
|
||||
constant.ForceStreamOption = GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true)
|
||||
constant.CountToken = GetEnvOrDefaultBool("CountToken", true)
|
||||
constant.GetMediaToken = GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
||||
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||
constant.GetMediaTokenNotStream = GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", false)
|
||||
constant.UpdateTask = GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||
constant.AzureDefaultAPIVersion = GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2025-04-01-preview")
|
||||
constant.GeminiVisionMaxImageNum = GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
|
||||
|
||||
@@ -46,5 +46,7 @@ const (
|
||||
ContextKeyUsingGroup ContextKey = "group"
|
||||
ContextKeyUserName ContextKey = "username"
|
||||
|
||||
ContextKeyLocalCountTokens ContextKey = "local_count_tokens"
|
||||
|
||||
ContextKeySystemPromptOverride ContextKey = "system_prompt_override"
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ var StreamingTimeout int
|
||||
var DifyDebug bool
|
||||
var MaxFileDownloadMB int
|
||||
var ForceStreamOption bool
|
||||
var CountToken bool
|
||||
var GetMediaToken bool
|
||||
var GetMediaTokenNotStream bool
|
||||
var UpdateTask bool
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/relay/channel/volcengine"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -91,7 +92,7 @@ func GetAllChannels(c *gin.Context) {
|
||||
if tag == nil || *tag == "" {
|
||||
continue
|
||||
}
|
||||
tagChannels, err := model.GetChannelsByTag(*tag, idSort)
|
||||
tagChannels, err := model.GetChannelsByTag(*tag, idSort, false)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -192,13 +193,29 @@ func FetchUpstreamModels(c *gin.Context) {
|
||||
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
|
||||
case constant.ChannelTypeZhipu_v4:
|
||||
url = fmt.Sprintf("%s/api/paas/v4/models", baseURL)
|
||||
case constant.ChannelTypeVolcEngine:
|
||||
if baseURL == volcengine.DoubaoCodingPlan {
|
||||
url = fmt.Sprintf("%s/v1/models", volcengine.DoubaoCodingPlanOpenAIBaseURL)
|
||||
} else {
|
||||
url = fmt.Sprintf("%s/v1/models", baseURL)
|
||||
}
|
||||
default:
|
||||
url = fmt.Sprintf("%s/v1/models", baseURL)
|
||||
}
|
||||
|
||||
// 获取用于请求的可用密钥(多密钥渠道优先使用启用状态的密钥)
|
||||
key, _, apiErr := channel.GetNextEnabledKey()
|
||||
if apiErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("获取渠道密钥失败: %s", apiErr.Error()),
|
||||
})
|
||||
return
|
||||
}
|
||||
key = strings.TrimSpace(key)
|
||||
|
||||
// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
|
||||
var body []byte
|
||||
key := strings.Split(channel.Key, "\n")[0]
|
||||
switch channel.Type {
|
||||
case constant.ChannelTypeAnthropic:
|
||||
body, err = GetResponseBody("GET", url, channel, GetClaudeAuthHeader(key))
|
||||
@@ -271,7 +288,7 @@ func SearchChannels(c *gin.Context) {
|
||||
}
|
||||
for _, tag := range tags {
|
||||
if tag != nil && *tag != "" {
|
||||
tagChannel, err := model.GetChannelsByTag(*tag, idSort)
|
||||
tagChannel, err := model.GetChannelsByTag(*tag, idSort, false)
|
||||
if err == nil {
|
||||
channelData = append(channelData, tagChannel...)
|
||||
}
|
||||
@@ -1021,7 +1038,7 @@ func GetTagModels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
channels, err := model.GetChannelsByTag(tag, false) // Assuming false for idSort is fine here
|
||||
channels, err := model.GetChannelsByTag(tag, false, false) // idSort=false, selectAll=false
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
|
||||
223
controller/discord.go
Normal file
223
controller/discord.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/QuantumNous/new-api/setting/system_setting"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DiscordResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
type DiscordUser struct {
|
||||
UID string `json:"id"`
|
||||
ID string `json:"username"`
|
||||
Name string `json:"global_name"`
|
||||
}
|
||||
|
||||
func getDiscordUserInfoByCode(code string) (*DiscordUser, error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("无效的参数")
|
||||
}
|
||||
|
||||
values := url.Values{}
|
||||
values.Set("client_id", system_setting.GetDiscordSettings().ClientId)
|
||||
values.Set("client_secret", system_setting.GetDiscordSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/discord", system_setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", "https://discord.com/api/v10/oauth2/token", strings.NewReader(formData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var discordResponse DiscordResponse
|
||||
err = json.NewDecoder(res.Body).Decode(&discordResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if discordResponse.AccessToken == "" {
|
||||
common.SysError("Discord 获取 Token 失败,请检查设置!")
|
||||
return nil, errors.New("Discord 获取 Token 失败,请检查设置!")
|
||||
}
|
||||
|
||||
req, err = http.NewRequest("GET", "https://discord.com/api/v10/users/@me", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+discordResponse.AccessToken)
|
||||
res2, err := client.Do(req)
|
||||
if err != nil {
|
||||
common.SysLog(err.Error())
|
||||
return nil, errors.New("无法连接至 Discord 服务器,请稍后重试!")
|
||||
}
|
||||
defer res2.Body.Close()
|
||||
if res2.StatusCode != http.StatusOK {
|
||||
common.SysError("Discord 获取用户信息失败!请检查设置!")
|
||||
return nil, errors.New("Discord 获取用户信息失败!请检查设置!")
|
||||
}
|
||||
|
||||
var discordUser DiscordUser
|
||||
err = json.NewDecoder(res2.Body).Decode(&discordUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if discordUser.UID == "" || discordUser.ID == "" {
|
||||
common.SysError("Discord 获取用户信息为空!请检查设置!")
|
||||
return nil, errors.New("Discord 获取用户信息为空!请检查设置!")
|
||||
}
|
||||
return &discordUser, nil
|
||||
}
|
||||
|
||||
func DiscordOAuth(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
state := c.Query("state")
|
||||
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "state is empty or not same",
|
||||
})
|
||||
return
|
||||
}
|
||||
username := session.Get("username")
|
||||
if username != nil {
|
||||
DiscordBind(c)
|
||||
return
|
||||
}
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
discordUser, err := getDiscordUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
DiscordId: discordUser.UID,
|
||||
}
|
||||
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
|
||||
err := user.FillUserByDiscordId()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if common.RegisterEnabled {
|
||||
if discordUser.ID != "" {
|
||||
user.Username = discordUser.ID
|
||||
} else {
|
||||
user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
|
||||
}
|
||||
if discordUser.Name != "" {
|
||||
user.DisplayName = discordUser.Name
|
||||
} else {
|
||||
user.DisplayName = "Discord User"
|
||||
}
|
||||
err := user.Insert(0)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员关闭了新用户注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if user.Status != common.UserStatusEnabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "用户已被封禁",
|
||||
"success": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
setupLogin(&user, c)
|
||||
}
|
||||
|
||||
func DiscordBind(c *gin.Context) {
|
||||
if !system_setting.GetDiscordSettings().Enabled {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "管理员未开启通过 Discord 登录以及注册",
|
||||
})
|
||||
return
|
||||
}
|
||||
code := c.Query("code")
|
||||
discordUser, err := getDiscordUserInfoByCode(code)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user := model.User{
|
||||
DiscordId: discordUser.UID,
|
||||
}
|
||||
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "该 Discord 账户已被绑定",
|
||||
})
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
id := session.Get("id")
|
||||
user.Id = id.(int)
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
user.DiscordId = discordUser.UID
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "bind",
|
||||
})
|
||||
}
|
||||
@@ -44,7 +44,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
client := http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Timeout: 20 * time.Second,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
|
||||
@@ -84,7 +84,7 @@ func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error)
|
||||
}
|
||||
|
||||
// Get access token using Basic auth
|
||||
tokenEndpoint := "https://connect.linux.do/oauth2/token"
|
||||
tokenEndpoint := common.GetEnvOrDefaultString("LINUX_DO_TOKEN_ENDPOINT", "https://connect.linux.do/oauth2/token")
|
||||
credentials := common.LinuxDOClientId + ":" + common.LinuxDOClientSecret
|
||||
basicAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(credentials))
|
||||
|
||||
@@ -129,7 +129,7 @@ func getLinuxdoUserInfoByCode(code string, c *gin.Context) (*LinuxdoUser, error)
|
||||
}
|
||||
|
||||
// Get user info
|
||||
userEndpoint := "https://connect.linux.do/api/user"
|
||||
userEndpoint := common.GetEnvOrDefaultString("LINUX_DO_USER_ENDPOINT", "https://connect.linux.do/api/user")
|
||||
req, err = http.NewRequest("GET", userEndpoint, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -52,6 +52,8 @@ func GetStatus(c *gin.Context) {
|
||||
"email_verification": common.EmailVerificationEnabled,
|
||||
"github_oauth": common.GitHubOAuthEnabled,
|
||||
"github_client_id": common.GitHubClientId,
|
||||
"discord_oauth": system_setting.GetDiscordSettings().Enabled,
|
||||
"discord_client_id": system_setting.GetDiscordSettings().ClientId,
|
||||
"linuxdo_oauth": common.LinuxDOOAuthEnabled,
|
||||
"linuxdo_client_id": common.LinuxDOClientId,
|
||||
"linuxdo_minimum_trust_level": common.LinuxDOMinimumTrustLevel,
|
||||
|
||||
@@ -71,6 +71,14 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
case "discord.enabled":
|
||||
if option.Value == "true" && system_setting.GetDiscordSettings().ClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 Discord OAuth,请先填入 Discord Client Id 以及 Discord Client Secret!",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "oidc.enabled":
|
||||
if option.Value == "true" && system_setting.GetOIDCSettings().ClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -453,6 +453,7 @@ func GetSelf(c *gin.Context) {
|
||||
"status": user.Status,
|
||||
"email": user.Email,
|
||||
"github_id": user.GitHubId,
|
||||
"discord_id": user.DiscordId,
|
||||
"oidc_id": user.OidcId,
|
||||
"wechat_id": user.WeChatId,
|
||||
"telegram_id": user.TelegramId,
|
||||
|
||||
@@ -42,6 +42,7 @@
|
||||
| 方法 | 路径 | 鉴权 | 说明 |
|
||||
|------|------|------|------|
|
||||
| GET | /api/oauth/github | 公开 | GitHub OAuth 跳转 |
|
||||
| GET | /api/oauth/discord | 公开 | Discord 通用 OAuth 跳转 |
|
||||
| GET | /api/oauth/oidc | 公开 | OIDC 通用 OAuth 跳转 |
|
||||
| GET | /api/oauth/linuxdo | 公开 | LinuxDo OAuth 跳转 |
|
||||
| GET | /api/oauth/wechat | 公开 | 微信扫码登录跳转 |
|
||||
|
||||
@@ -141,6 +141,39 @@ func (r *GeminiChatRequest) SetTools(tools []GeminiChatTool) {
|
||||
type GeminiThinkingConfig struct {
|
||||
IncludeThoughts bool `json:"includeThoughts,omitempty"`
|
||||
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
|
||||
// TODO Conflict with thinkingbudget.
|
||||
ThinkingLevel json.RawMessage `json:"thinkingLevel,omitempty"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON allows GeminiThinkingConfig to accept both snake_case and camelCase fields.
|
||||
func (c *GeminiThinkingConfig) UnmarshalJSON(data []byte) error {
|
||||
type Alias GeminiThinkingConfig
|
||||
var aux struct {
|
||||
Alias
|
||||
IncludeThoughtsSnake *bool `json:"include_thoughts,omitempty"`
|
||||
ThinkingBudgetSnake *int `json:"thinking_budget,omitempty"`
|
||||
ThinkingLevelSnake json.RawMessage `json:"thinking_level,omitempty"`
|
||||
}
|
||||
|
||||
if err := common.Unmarshal(data, &aux); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*c = GeminiThinkingConfig(aux.Alias)
|
||||
|
||||
if aux.IncludeThoughtsSnake != nil {
|
||||
c.IncludeThoughts = *aux.IncludeThoughtsSnake
|
||||
}
|
||||
|
||||
if aux.ThinkingBudgetSnake != nil {
|
||||
c.ThinkingBudget = aux.ThinkingBudgetSnake
|
||||
}
|
||||
|
||||
if len(aux.ThinkingLevelSnake) > 0 {
|
||||
c.ThinkingLevel = aux.ThinkingLevelSnake
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
|
||||
@@ -182,8 +215,12 @@ type FunctionCall struct {
|
||||
}
|
||||
|
||||
type GeminiFunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response map[string]interface{} `json:"response"`
|
||||
Name string `json:"name"`
|
||||
Response map[string]interface{} `json:"response"`
|
||||
WillContinue json.RawMessage `json:"willContinue,omitempty"`
|
||||
Scheduling json.RawMessage `json:"scheduling,omitempty"`
|
||||
Parts json.RawMessage `json:"parts,omitempty"`
|
||||
ID json.RawMessage `json:"id,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiPartExecutableCode struct {
|
||||
@@ -202,11 +239,15 @@ type GeminiFileData struct {
|
||||
}
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Thought bool `json:"thought,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
ThoughtSignature json.RawMessage `json:"thoughtSignature,omitempty"`
|
||||
FunctionResponse *GeminiFunctionResponse `json:"functionResponse,omitempty"`
|
||||
// Optional. Media resolution for the input media.
|
||||
MediaResolution json.RawMessage `json:"mediaResolution,omitempty"`
|
||||
VideoMetadata json.RawMessage `json:"videoMetadata,omitempty"`
|
||||
FileData *GeminiFileData `json:"fileData,omitempty"`
|
||||
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
|
||||
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
|
||||
|
||||
@@ -66,10 +66,11 @@ type GeneralOpenAIRequest struct {
|
||||
// 注意:默认过滤此字段以保护用户隐私,但过滤后可能导致 Codex 无法正常使用
|
||||
Store json.RawMessage `json:"store,omitempty"`
|
||||
// Used by OpenAI to cache responses for similar requests to optimize your cache hit rates. Replaces the user field
|
||||
PromptCacheKey string `json:"prompt_cache_key,omitempty"`
|
||||
LogitBias json.RawMessage `json:"logit_bias,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
Prediction json.RawMessage `json:"prediction,omitempty"`
|
||||
PromptCacheKey string `json:"prompt_cache_key,omitempty"`
|
||||
PromptCacheRetention json.RawMessage `json:"prompt_cache_retention,omitempty"`
|
||||
LogitBias json.RawMessage `json:"logit_bias,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
Prediction json.RawMessage `json:"prediction,omitempty"`
|
||||
// gemini
|
||||
ExtraBody json.RawMessage `json:"extra_body,omitempty"`
|
||||
//xai
|
||||
@@ -798,19 +799,20 @@ type OpenAIResponsesRequest struct {
|
||||
PreviousResponseID string `json:"previous_response_id,omitempty"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
// 服务层级字段,用于指定 API 服务等级。允许透传可能导致实际计费高于预期,默认应过滤
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Store json.RawMessage `json:"store,omitempty"`
|
||||
PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Truncation string `json:"truncation,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
|
||||
Prompt json.RawMessage `json:"prompt,omitempty"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
Store json.RawMessage `json:"store,omitempty"`
|
||||
PromptCacheKey json.RawMessage `json:"prompt_cache_key,omitempty"`
|
||||
PromptCacheRetention json.RawMessage `json:"prompt_cache_retention,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
|
||||
Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Truncation string `json:"truncation,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
|
||||
Prompt json.RawMessage `json:"prompt,omitempty"`
|
||||
}
|
||||
|
||||
func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
@@ -895,6 +897,12 @@ type Reasoning struct {
|
||||
Summary string `json:"summary,omitempty"`
|
||||
}
|
||||
|
||||
type Input struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content json.RawMessage `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
type MediaInput struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
@@ -913,7 +921,7 @@ func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
|
||||
return nil
|
||||
}
|
||||
|
||||
var inputs []MediaInput
|
||||
var mediaInputs []MediaInput
|
||||
|
||||
// Try string first
|
||||
// if str, ok := common.GetJsonType(r.Input); ok {
|
||||
@@ -923,60 +931,74 @@ func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
|
||||
if common.GetJsonType(r.Input) == "string" {
|
||||
var str string
|
||||
_ = common.Unmarshal(r.Input, &str)
|
||||
inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
|
||||
return inputs
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_text", Text: str})
|
||||
return mediaInputs
|
||||
}
|
||||
|
||||
// Try array of parts
|
||||
if common.GetJsonType(r.Input) == "array" {
|
||||
var array []any
|
||||
_ = common.Unmarshal(r.Input, &array)
|
||||
for _, itemAny := range array {
|
||||
// Already parsed MediaInput
|
||||
if media, ok := itemAny.(MediaInput); ok {
|
||||
inputs = append(inputs, media)
|
||||
continue
|
||||
var inputs []Input
|
||||
_ = common.Unmarshal(r.Input, &inputs)
|
||||
for _, input := range inputs {
|
||||
if common.GetJsonType(input.Content) == "string" {
|
||||
var str string
|
||||
_ = common.Unmarshal(input.Content, &str)
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_text", Text: str})
|
||||
}
|
||||
// Generic map
|
||||
item, ok := itemAny.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typeVal, ok := item["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch typeVal {
|
||||
case "input_text":
|
||||
text, _ := item["text"].(string)
|
||||
inputs = append(inputs, MediaInput{Type: "input_text", Text: text})
|
||||
case "input_image":
|
||||
// image_url may be string or object with url field
|
||||
var imageUrl string
|
||||
switch v := item["image_url"].(type) {
|
||||
case string:
|
||||
imageUrl = v
|
||||
case map[string]any:
|
||||
if url, ok := v["url"].(string); ok {
|
||||
imageUrl = url
|
||||
|
||||
if common.GetJsonType(input.Content) == "array" {
|
||||
var array []any
|
||||
_ = common.Unmarshal(input.Content, &array)
|
||||
for _, itemAny := range array {
|
||||
// Already parsed MediaContent
|
||||
if media, ok := itemAny.(MediaInput); ok {
|
||||
mediaInputs = append(mediaInputs, media)
|
||||
continue
|
||||
}
|
||||
|
||||
// Generic map
|
||||
item, ok := itemAny.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
typeVal, ok := item["type"].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch typeVal {
|
||||
case "input_text":
|
||||
text, _ := item["text"].(string)
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_text", Text: text})
|
||||
case "input_image":
|
||||
// image_url may be string or object with url field
|
||||
var imageUrl string
|
||||
switch v := item["image_url"].(type) {
|
||||
case string:
|
||||
imageUrl = v
|
||||
case map[string]any:
|
||||
if url, ok := v["url"].(string); ok {
|
||||
imageUrl = url
|
||||
}
|
||||
}
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
|
||||
case "input_file":
|
||||
// file_url may be string or object with url field
|
||||
var fileUrl string
|
||||
switch v := item["file_url"].(type) {
|
||||
case string:
|
||||
fileUrl = v
|
||||
case map[string]any:
|
||||
if url, ok := v["url"].(string); ok {
|
||||
fileUrl = url
|
||||
}
|
||||
}
|
||||
mediaInputs = append(mediaInputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
|
||||
}
|
||||
}
|
||||
inputs = append(inputs, MediaInput{Type: "input_image", ImageUrl: imageUrl})
|
||||
case "input_file":
|
||||
// file_url may be string or object with url field
|
||||
var fileUrl string
|
||||
switch v := item["file_url"].(type) {
|
||||
case string:
|
||||
fileUrl = v
|
||||
case map[string]any:
|
||||
if url, ok := v["url"].(string); ok {
|
||||
fileUrl = url
|
||||
}
|
||||
}
|
||||
inputs = append(inputs, MediaInput{Type: "input_file", FileUrl: fileUrl})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return inputs
|
||||
return mediaInputs
|
||||
}
|
||||
|
||||
6
electron/package-lock.json
generated
6
electron/package-lock.json
generated
@@ -2784,9 +2784,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/js-yaml": {
|
||||
"version": "4.1.0",
|
||||
"resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.0.tgz",
|
||||
"integrity": "sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==",
|
||||
"version": "4.1.1",
|
||||
"resolved": "https://registry.npmjs.org/js-yaml/-/js-yaml-4.1.1.tgz",
|
||||
"integrity": "sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
||||
10
go.mod
10
go.mod
@@ -43,10 +43,10 @@ require (
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/tiktoken-go/tokenizer v0.6.2
|
||||
github.com/yapingcat/gomedia v0.0.0-20240906162731-17feea57090c
|
||||
golang.org/x/crypto v0.42.0
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/image v0.23.0
|
||||
golang.org/x/net v0.43.0
|
||||
golang.org/x/sync v0.17.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/sync v0.18.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
gorm.io/gorm v1.25.2
|
||||
@@ -111,8 +111,8 @@ require (
|
||||
github.com/yusufpapurcu/wmi v1.2.3 // indirect
|
||||
golang.org/x/arch v0.21.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 // indirect
|
||||
golang.org/x/sys v0.36.0 // indirect
|
||||
golang.org/x/text v0.29.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.22.5 // indirect
|
||||
|
||||
20
go.sum
20
go.sum
@@ -281,18 +281,18 @@ go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
golang.org/x/arch v0.21.0 h1:iTC9o7+wP6cPWpDWkivCvQFGAHDQ59SrSxsLPcnkArw=
|
||||
golang.org/x/arch v0.21.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
|
||||
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
|
||||
golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
|
||||
golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68=
|
||||
golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -304,15 +304,15 @@ golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
|
||||
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
|
||||
@@ -272,13 +272,17 @@ func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Chan
|
||||
return channels, err
|
||||
}
|
||||
|
||||
func GetChannelsByTag(tag string, idSort bool) ([]*Channel, error) {
|
||||
func GetChannelsByTag(tag string, idSort bool, selectAll bool) ([]*Channel, error) {
|
||||
var channels []*Channel
|
||||
order := "priority desc"
|
||||
if idSort {
|
||||
order = "id desc"
|
||||
}
|
||||
err := DB.Where("tag = ?", tag).Order(order).Find(&channels).Error
|
||||
query := DB.Where("tag = ?", tag).Order(order)
|
||||
if !selectAll {
|
||||
query = query.Omit("key")
|
||||
}
|
||||
err := query.Find(&channels).Error
|
||||
return channels, err
|
||||
}
|
||||
|
||||
@@ -728,7 +732,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
|
||||
return err
|
||||
}
|
||||
if shouldReCreateAbilities {
|
||||
channels, err := GetChannelsByTag(updatedTag, false)
|
||||
channels, err := GetChannelsByTag(updatedTag, false, false)
|
||||
if err == nil {
|
||||
for _, channel := range channels {
|
||||
err = channel.UpdateAbilities(nil)
|
||||
|
||||
@@ -27,6 +27,7 @@ type User struct {
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
DiscordId string `json:"discord_id" gorm:"column:discord_id;index"`
|
||||
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
|
||||
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
|
||||
TelegramId string `json:"telegram_id" gorm:"column:telegram_id;index"`
|
||||
@@ -539,6 +540,14 @@ func (user *User) FillUserByGitHubId() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByDiscordId() error {
|
||||
if user.DiscordId == "" {
|
||||
return errors.New("discord id 为空!")
|
||||
}
|
||||
DB.Where(User{DiscordId: user.DiscordId}).First(user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) FillUserByOidcId() error {
|
||||
if user.OidcId == "" {
|
||||
return errors.New("oidc id 为空!")
|
||||
@@ -578,6 +587,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
|
||||
return DB.Unscoped().Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsDiscordIdAlreadyTaken(discordId string) bool {
|
||||
return DB.Unscoped().Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
func IsOidcIdAlreadyTaken(oidcId string) bool {
|
||||
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
|
||||
}
|
||||
|
||||
@@ -47,7 +47,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
||||
case constant.RelayModeImagesEdits:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
||||
if isWanModel(info.OriginModelName) {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image2image/image-synthesis", info.ChannelBaseUrl)
|
||||
} else {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
||||
}
|
||||
case constant.RelayModeCompletions:
|
||||
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl)
|
||||
default:
|
||||
@@ -71,6 +75,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
req.Set("X-DashScope-Async", "enable")
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
if isWanModel(info.OriginModelName) {
|
||||
req.Set("X-DashScope-Async", "enable")
|
||||
}
|
||||
req.Set("Content-Type", "application/json")
|
||||
}
|
||||
return nil
|
||||
@@ -107,6 +114,9 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
return aliRequest, nil
|
||||
} else if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
if isWanModel(info.OriginModelName) {
|
||||
return oaiFormEdit2WanxImageEdit(c, info, request)
|
||||
}
|
||||
// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
|
||||
// 如果用户使用表单,则需要解析表单数据
|
||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
@@ -161,7 +171,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
case constant.RelayModeImagesEdits:
|
||||
err, usage = aliImageEditHandler(c, resp, info)
|
||||
if isWanModel(info.OriginModelName) {
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = aliImageEditHandler(c, resp, info)
|
||||
}
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = RerankHandler(c, resp, info)
|
||||
default:
|
||||
|
||||
@@ -112,6 +112,19 @@ type AliImageInput struct {
|
||||
Messages []AliMessage `json:"messages,omitempty"`
|
||||
}
|
||||
|
||||
type WanImageInput struct {
|
||||
Prompt string `json:"prompt"` // 必需:文本提示词,描述生成图像中期望包含的元素和视觉特点
|
||||
Images []string `json:"images"` // 必需:图像URL数组,长度不超过2,支持HTTP/HTTPS URL或Base64编码
|
||||
NegativePrompt string `json:"negative_prompt,omitempty"` // 可选:反向提示词,描述不希望在画面中看到的内容
|
||||
}
|
||||
|
||||
type WanImageParameters struct {
|
||||
N int `json:"n,omitempty"` // 生成图片数量,取值范围1-4,默认4
|
||||
Watermark *bool `json:"watermark,omitempty"` // 是否添加水印标识,默认false
|
||||
Seed int `json:"seed,omitempty"` // 随机数种子,取值范围[0, 2147483647]
|
||||
Strength float64 `json:"strength,omitempty"` // 修改幅度 0.0-1.0,默认0.5(部分模型支持)
|
||||
}
|
||||
|
||||
type AliRerankParameters struct {
|
||||
TopN *int `json:"top_n,omitempty"`
|
||||
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||
|
||||
@@ -58,11 +58,7 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
return &imageRequest, nil
|
||||
}
|
||||
|
||||
func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
var imageRequest AliImageRequest
|
||||
imageRequest.Model = request.Model
|
||||
imageRequest.ResponseFormat = request.ResponseFormat
|
||||
|
||||
func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
|
||||
mf := c.Request.MultipartForm
|
||||
if mf == nil {
|
||||
if _, err := c.MultipartForm(); err != nil {
|
||||
@@ -127,7 +123,18 @@ func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, reque
|
||||
imageBase64s = append(imageBase64s, dataURL)
|
||||
image.Close()
|
||||
}
|
||||
return imageBase64s, nil
|
||||
}
|
||||
|
||||
func oaiFormEdit2AliImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
var imageRequest AliImageRequest
|
||||
imageRequest.Model = request.Model
|
||||
imageRequest.ResponseFormat = request.ResponseFormat
|
||||
|
||||
imageBase64s, err := getImageBase64sFromForm(c, "image")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get image base64s from form failed: %w", err)
|
||||
}
|
||||
//dto.MediaContent{}
|
||||
mediaContents := make([]AliMediaContent, len(imageBase64s))
|
||||
for i, b64 := range imageBase64s {
|
||||
|
||||
39
relay/channel/ali/image_wan.go
Normal file
39
relay/channel/ali/image_wan.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func oaiFormEdit2WanxImageEdit(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
var err error
|
||||
var imageRequest AliImageRequest
|
||||
imageRequest.Model = request.Model
|
||||
imageRequest.ResponseFormat = request.ResponseFormat
|
||||
wanInput := WanImageInput{
|
||||
Prompt: request.Prompt,
|
||||
}
|
||||
|
||||
if err := common.UnmarshalBodyReusable(c, &wanInput); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if wanInput.Images, err = getImageBase64sFromForm(c, "image"); err != nil {
|
||||
return nil, fmt.Errorf("get image base64s from form failed: %w", err)
|
||||
}
|
||||
wanParams := WanImageParameters{
|
||||
N: int(request.N),
|
||||
}
|
||||
imageRequest.Input = wanInput
|
||||
imageRequest.Parameters = wanParams
|
||||
return &imageRequest, nil
|
||||
}
|
||||
|
||||
func isWanModel(modelName string) bool {
|
||||
return strings.Contains(modelName, "wan")
|
||||
}
|
||||
@@ -1,12 +1,15 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
)
|
||||
|
||||
type AwsClaudeRequest struct {
|
||||
@@ -34,17 +37,19 @@ func formatRequest(requestBody io.Reader, requestHeader http.Header) (*AwsClaude
|
||||
awsClaudeRequest.AnthropicVersion = "bedrock-2023-05-31"
|
||||
|
||||
// check header anthropic-beta
|
||||
anthropicBetaValues := requestHeader.Values("anthropic-beta")
|
||||
anthropicBetaValues := requestHeader.Get("anthropic-beta")
|
||||
if len(anthropicBetaValues) > 0 {
|
||||
betaJson, err := json.Marshal(anthropicBetaValues)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var tempArray []string
|
||||
if err := json.Unmarshal(betaJson, &tempArray); err == nil && len(tempArray) != 0 && len(betaJson) > 0 {
|
||||
awsClaudeRequest.AnthropicBeta = json.RawMessage(betaJson)
|
||||
tempArray = strings.Split(anthropicBetaValues, ",")
|
||||
if len(tempArray) > 0 {
|
||||
betaJson, err := json.Marshal(tempArray)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
awsClaudeRequest.AnthropicBeta = betaJson
|
||||
}
|
||||
}
|
||||
logger.LogJson(context.Background(), "json", awsClaudeRequest)
|
||||
return &awsClaudeRequest, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -673,7 +673,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||
|
||||
if requestMode == RequestModeCompletion {
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
if claudeInfo.Usage.PromptTokens == 0 {
|
||||
//上游出错
|
||||
@@ -682,7 +682,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
||||
if common.DebugEnabled {
|
||||
common.SysLog("claude response usage is not complete, maybe upstream error")
|
||||
}
|
||||
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
claudeInfo.Usage = service.ResponseText2Usage(c, claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
|
||||
if err := scanner.Err(); err != nil {
|
||||
logger.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||
}
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
if info.ShouldIncludeUsage {
|
||||
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := helper.ObjectData(c, response)
|
||||
@@ -105,7 +105,7 @@ func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response)
|
||||
for _, choice := range response.Choices {
|
||||
responseText += choice.Message.StringContent()
|
||||
}
|
||||
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage := service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
response.Usage = *usage
|
||||
response.Id = helper.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
|
||||
@@ -165,7 +165,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
}
|
||||
})
|
||||
if usage.PromptTokens == 0 {
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -142,7 +142,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
|
||||
helper.Done(c)
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
|
||||
@@ -246,7 +246,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
|
||||
})
|
||||
helper.Done(c)
|
||||
if usage.TotalTokens == 0 {
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
}
|
||||
usage.CompletionTokens += nodeToken
|
||||
return usage, nil
|
||||
|
||||
@@ -177,7 +177,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
||||
geminiRequest, err := CovertGemini2OpenAI(c, *request, info)
|
||||
geminiRequest, err := CovertOpenAI2Gemini(c, *request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ var ModelList = []string{
|
||||
"gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
|
||||
// preview version
|
||||
"gemini-2.0-flash-lite-preview",
|
||||
"gemini-3-pro-preview",
|
||||
// gemini exp
|
||||
"gemini-exp-1206",
|
||||
// flash exp
|
||||
@@ -31,7 +32,7 @@ var SafetySettingList = []string{
|
||||
"HARM_CATEGORY_HATE_SPEECH",
|
||||
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"HARM_CATEGORY_CIVIC_INTEGRITY",
|
||||
//"HARM_CATEGORY_CIVIC_INTEGRITY", This item is deprecated!
|
||||
}
|
||||
|
||||
var ChannelName = "google gemini"
|
||||
|
||||
@@ -3,9 +3,9 @@ package gemini
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
@@ -13,8 +13,6 @@ import (
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -77,6 +75,8 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
|
||||
TotalTokens: info.PromptTokens,
|
||||
}
|
||||
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
|
||||
if info.IsGeminiBatchEmbedding {
|
||||
var geminiResponse dto.GeminiBatchEmbeddingResponse
|
||||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||||
@@ -97,80 +97,15 @@ func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *rel
|
||||
}
|
||||
|
||||
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
// 统计图片数量
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||
imageCount++
|
||||
}
|
||||
if part.Text != "" {
|
||||
responseText.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新使用量统计
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
} else if detail.Modality == "TEXT" {
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
|
||||
// 直接发送 GeminiChatResponse 响应
|
||||
err = helper.StringData(c, data)
|
||||
err := helper.StringData(c, data)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
info.SendResponseCount++
|
||||
return true
|
||||
})
|
||||
|
||||
if info.SendResponseCount == 0 {
|
||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 258
|
||||
}
|
||||
}
|
||||
|
||||
// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
|
||||
if usage.CompletionTokens == 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 空补全,不需要使用量
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
}
|
||||
|
||||
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
||||
//helper.Done(c)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -44,6 +44,8 @@ var geminiSupportedMimeTypes = map[string]bool{
|
||||
"video/flv": true,
|
||||
}
|
||||
|
||||
const thoughtSignatureBypassValue = "context_engineering_is_the_way_to_go"
|
||||
|
||||
// Gemini 允许的思考预算范围
|
||||
const (
|
||||
pro25MinBudget = 128
|
||||
@@ -181,7 +183,7 @@ func ThinkingAdaptor(geminiRequest *dto.GeminiChatRequest, info *relaycommon.Rel
|
||||
}
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
|
||||
func CovertOpenAI2Gemini(c *gin.Context, textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*dto.GeminiChatRequest, error) {
|
||||
|
||||
geminiRequest := dto.GeminiChatRequest{
|
||||
Contents: make([]dto.GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
@@ -193,6 +195,10 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
},
|
||||
}
|
||||
|
||||
attachThoughtSignature := (info.ChannelType == constant.ChannelTypeGemini ||
|
||||
info.ChannelType == constant.ChannelTypeVertexAi) &&
|
||||
model_setting.GetGeminiSettings().FunctionCallThoughtSignatureEnabled
|
||||
|
||||
if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
|
||||
geminiRequest.GenerationConfig.ResponseModalities = []string{
|
||||
"TEXT",
|
||||
@@ -371,6 +377,8 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
content := dto.GeminiChatContent{
|
||||
Role: message.Role,
|
||||
}
|
||||
shouldAttachThoughtSignature := attachThoughtSignature && (message.Role == "assistant" || message.Role == "model")
|
||||
signatureAttached := false
|
||||
// isToolCall := false
|
||||
if message.ToolCalls != nil {
|
||||
// message.Role = "model"
|
||||
@@ -388,6 +396,10 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
if shouldAttachThoughtSignature && !signatureAttached && hasFunctionCallContent(toolCall.FunctionCall) && len(toolCall.ThoughtSignature) == 0 {
|
||||
toolCall.ThoughtSignature = json.RawMessage(strconv.Quote(thoughtSignatureBypassValue))
|
||||
signatureAttached = true
|
||||
}
|
||||
parts = append(parts, toolCall)
|
||||
tool_call_ids[call.ID] = call.Function.Name
|
||||
}
|
||||
@@ -496,6 +508,28 @@ func CovertGemini2OpenAI(c *gin.Context, textRequest dto.GeneralOpenAIRequest, i
|
||||
return &geminiRequest, nil
|
||||
}
|
||||
|
||||
func hasFunctionCallContent(call *dto.FunctionCall) bool {
|
||||
if call == nil {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(call.FunctionName) != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
switch v := call.Arguments.(type) {
|
||||
case nil:
|
||||
return false
|
||||
case string:
|
||||
return strings.TrimSpace(v) != ""
|
||||
case map[string]interface{}:
|
||||
return len(v) > 0
|
||||
case []interface{}:
|
||||
return len(v) > 0
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to get a list of supported MIME types for error messages
|
||||
func getSupportedMimeTypesList() []string {
|
||||
keys := make([]string, 0, len(geminiSupportedMimeTypes))
|
||||
@@ -920,14 +954,10 @@ func handleFinalStream(c *gin.Context, info *relaycommon.RelayInfo, resp *dto.Ch
|
||||
return nil
|
||||
}
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
// responseText := ""
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
responseText := strings.Builder{}
|
||||
func geminiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response, callback func(data string, geminiResponse *dto.GeminiChatResponse) bool) (*dto.Usage, *types.NewAPIError) {
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
finishReason := constant.FinishReasonStop
|
||||
responseText := strings.Builder{}
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse dto.GeminiChatResponse
|
||||
@@ -937,6 +967,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
return false
|
||||
}
|
||||
|
||||
// 统计图片数量
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||
@@ -948,14 +979,10 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
}
|
||||
}
|
||||
|
||||
response, isStop := streamResponseGeminiChat2OpenAI(&geminiResponse)
|
||||
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
// 更新使用量统计
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
@@ -966,6 +993,45 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return callback(data, &geminiResponse)
|
||||
})
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 1400
|
||||
}
|
||||
}
|
||||
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
if usage.TotalTokens > 0 {
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
}
|
||||
|
||||
if usage.CompletionTokens <= 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(c, responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
id := helper.GetResponseID(c)
|
||||
createAt := common.GetTimestamp()
|
||||
finishReason := constant.FinishReasonStop
|
||||
|
||||
usage, err := geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
|
||||
response, isStop := streamResponseGeminiChat2OpenAI(geminiResponse)
|
||||
|
||||
response.Id = id
|
||||
response.Created = createAt
|
||||
response.Model = info.UpstreamModelName
|
||||
|
||||
logger.LogDebug(c, fmt.Sprintf("info.SendResponseCount = %d", info.SendResponseCount))
|
||||
if info.SendResponseCount == 0 {
|
||||
// send first response
|
||||
@@ -981,7 +1047,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
emptyResponse.Choices[0].Delta.ToolCalls = copiedToolCalls
|
||||
}
|
||||
finishReason = constant.FinishReasonToolCalls
|
||||
err = handleStream(c, info, emptyResponse)
|
||||
err := handleStream(c, info, emptyResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
@@ -991,14 +1057,14 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
response.Choices[0].FinishReason = nil
|
||||
}
|
||||
} else {
|
||||
err = handleStream(c, info, emptyResponse)
|
||||
err := handleStream(c, info, emptyResponse)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = handleStream(c, info, response)
|
||||
err := handleStream(c, info, response)
|
||||
if err != nil {
|
||||
logger.LogError(c, err.Error())
|
||||
}
|
||||
@@ -1008,40 +1074,15 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
|
||||
return true
|
||||
})
|
||||
|
||||
if info.SendResponseCount == 0 {
|
||||
// 空补全,报错不计费
|
||||
// empty response, throw an error
|
||||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 258
|
||||
}
|
||||
}
|
||||
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
|
||||
if usage.CompletionTokens == 0 {
|
||||
str := responseText.String()
|
||||
if len(str) > 0 {
|
||||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
// 空补全,不需要使用量
|
||||
usage = &dto.Usage{}
|
||||
}
|
||||
if err != nil {
|
||||
return usage, err
|
||||
}
|
||||
|
||||
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||
err := handleFinalStream(c, info, response)
|
||||
if err != nil {
|
||||
common.SysLog("send final response failed: " + err.Error())
|
||||
handleErr := handleFinalStream(c, info, response)
|
||||
if handleErr != nil {
|
||||
common.SysLog("send final response failed: " + handleErr.Error())
|
||||
}
|
||||
//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
|
||||
// helper.Done(c)
|
||||
//}
|
||||
//resp.Body.Close()
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
@@ -18,10 +19,26 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// 辅助函数
|
||||
// HandleStreamFormat processes a streaming response payload according to the provided RelayInfo and forwards it to the appropriate format-specific handler.
|
||||
//
|
||||
// It increments info.SendResponseCount, optionally converts OpenRouter "reasoning" fields to "reasoning_content" when the channel is OpenRouter and OpenRouterConvertToOpenAI is enabled, and then dispatches the (possibly modified) JSON string to the handler for the configured RelayFormat (OpenAI, Claude, or Gemini). It returns any error produced by the selected handler or nil if no handler is invoked.
|
||||
func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
|
||||
info.SendResponseCount++
|
||||
|
||||
// OpenRouter reasoning 字段转换:reasoning -> reasoning_content
|
||||
// 仅当启用转换为OpenAI兼容格式时执行
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.OpenRouterConvertToOpenAI {
|
||||
var streamResponse dto.ChatCompletionsStreamResponse
|
||||
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err == nil {
|
||||
convertOpenRouterReasoningFieldsStream(&streamResponse)
|
||||
// 重新序列化为JSON
|
||||
newData, err := common.Marshal(streamResponse)
|
||||
if err == nil {
|
||||
data = string(newData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatOpenAI:
|
||||
return sendStreamData(c, info, data, forceFormat, thinkToContent)
|
||||
@@ -253,9 +270,26 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
|
||||
}
|
||||
}
|
||||
|
||||
// sendResponsesStreamData sends a non-empty data chunk for the given stream response to the client.
|
||||
// If data is empty, it returns without sending anything.
|
||||
func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
|
||||
if data == "" {
|
||||
return
|
||||
}
|
||||
helper.ResponseChunkData(c, streamResponse, data)
|
||||
}
|
||||
|
||||
// convertOpenRouterReasoningFieldsStream converts each choice's `Delta` in a streaming ChatCompletions response
|
||||
// by normalizing any `reasoning` fields into `reasoning_content`.
|
||||
// It applies ConvertReasoningField to every choice's Delta and is a no-op if `response` is nil or has no choices.
|
||||
func convertOpenRouterReasoningFieldsStream(response *dto.ChatCompletionsStreamResponse) {
|
||||
if response == nil || len(response.Choices) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 遍历所有choices,对每个Delta使用统一的泛型函数进行转换
|
||||
for i := range response.Choices {
|
||||
choice := &response.Choices[i]
|
||||
ConvertReasoningField(&choice.Delta)
|
||||
}
|
||||
}
|
||||
35
relay/channel/openai/reasoning_converter.go
Normal file
35
relay/channel/openai/reasoning_converter.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package openai
|
||||
|
||||
// ReasoningHolder 定义一个通用的接口,用于操作包含reasoning字段的结构体
|
||||
type ReasoningHolder interface {
|
||||
// 获取reasoning字段的值
|
||||
GetReasoning() string
|
||||
// 设置reasoning字段的值
|
||||
SetReasoning(reasoning string)
|
||||
// 获取reasoning_content字段的值
|
||||
GetReasoningContent() string
|
||||
// 设置reasoning_content字段的值
|
||||
SetReasoningContent(reasoningContent string)
|
||||
}
|
||||
|
||||
// ConvertReasoningField 通用的reasoning字段转换函数
|
||||
// 将reasoning字段的内容移动到reasoning_content字段
|
||||
// ConvertReasoningField moves the holder's reasoning into its reasoning content and clears the original reasoning field.
|
||||
// If GetReasoning returns an empty string, the holder is unchanged. When clearing, types that implement SetReasoningToNil()
|
||||
// will have that method invoked; otherwise SetReasoning("") is used.
|
||||
func ConvertReasoningField[T ReasoningHolder](holder T) {
|
||||
reasoning := holder.GetReasoning()
|
||||
if reasoning != "" {
|
||||
holder.SetReasoningContent(reasoning)
|
||||
}
|
||||
|
||||
// 使用类型断言来智能清理reasoning字段
|
||||
switch h := any(holder).(type) {
|
||||
case interface{ SetReasoningToNil() }:
|
||||
// 流式响应:指针类型,设为nil
|
||||
h.SetReasoningToNil()
|
||||
default:
|
||||
// 非流式响应:值类型,设为空字符串
|
||||
holder.SetReasoning("")
|
||||
}
|
||||
}
|
||||
@@ -183,7 +183,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
}
|
||||
|
||||
if !containStreamUsage {
|
||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
@@ -194,6 +194,25 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
// OpenaiHandler processes an upstream OpenAI-like HTTP response, normalizes or infers token usage,
|
||||
// optionally converts OpenRouter reasoning fields to OpenAI-compatible `reasoning_content`, adapts
|
||||
// the response to the configured relay format (OpenAI, Claude, or Gemini), writes the final body
|
||||
// to the client, and returns the computed usage.
|
||||
//
|
||||
// It will:
|
||||
// - Handle OpenRouter enterprise wrapper responses when the channel is OpenRouter Enterprise.
|
||||
// - Unmarshal the upstream body into an internal simple response and, when configured,
|
||||
// convert OpenRouter `reasoning` fields into `reasoning_content`.
|
||||
// - If usage prompt tokens are missing, infer completion tokens by counting tokens in choices
|
||||
// (falling back to per-choice text token counting) and set Prompt/Completion/Total tokens.
|
||||
// - Apply channel-specific post-processing to usage (cached token adjustments).
|
||||
// - Depending on RelayFormat and channel settings, inject updated usage into the body,
|
||||
// reserialize the converted simple response when ForceFormat is enabled or when OpenRouter
|
||||
// conversion was applied, or convert the response to Claude/Gemini formats.
|
||||
// - Write the final response body to the client via a graceful copy helper.
|
||||
//
|
||||
// Returns the final usage (possibly inferred or modified) or a NewAPIError describing any failure
|
||||
// encountered while reading, parsing, or transforming the upstream response.
|
||||
func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
@@ -226,6 +245,12 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// OpenRouter reasoning 字段转换:reasoning -> reasoning_content
|
||||
// 仅当启用转换为OpenAI兼容格式时执行(修改现有无条件转换)
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.OpenRouterConvertToOpenAI {
|
||||
convertOpenRouterReasoningFields(&simpleResponse)
|
||||
}
|
||||
|
||||
if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
|
||||
return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
|
||||
}
|
||||
@@ -271,6 +296,13 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
} else {
|
||||
// 对于 OpenRouter,仅在执行转换后重新序列化
|
||||
if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.OpenRouterConvertToOpenAI {
|
||||
responseBody, err = common.Marshal(simpleResponse)
|
||||
if err != nil {
|
||||
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case types.RelayFormatClaude:
|
||||
@@ -672,6 +704,10 @@ func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, res
|
||||
}
|
||||
}
|
||||
|
||||
// extractCachedTokensFromBody extracts a cached token count from a JSON response body.
|
||||
// It looks for cached token values in the following fields (in order): `usage.prompt_tokens_details.cached_tokens`,
|
||||
// `usage.cached_tokens`, and `usage.prompt_cache_hit_tokens`. It returns the first found value and `true`;
|
||||
// if none are present or the body cannot be parsed, it returns 0 and `false`.
|
||||
func extractCachedTokensFromBody(body []byte) (int, bool) {
|
||||
if len(body) == 0 {
|
||||
return 0, false
|
||||
@@ -702,3 +738,18 @@ func extractCachedTokensFromBody(body []byte) (int, bool) {
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// convertOpenRouterReasoningFields 转换OpenRouter响应中的reasoning字段为reasoning_content
|
||||
// convertOpenRouterReasoningFields converts OpenRouter-style `reasoning` fields into `reasoning_content` for every choice's message in the provided OpenAITextResponse.
|
||||
// It modifies the response in place and is a no-op if `response` is nil or contains no choices.
|
||||
func convertOpenRouterReasoningFields(response *dto.OpenAITextResponse) {
|
||||
if response == nil || len(response.Choices) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// 遍历所有choices,对每个Message使用统一的泛型函数进行转换
|
||||
for i := range response.Choices {
|
||||
choice := &response.Choices[i]
|
||||
ConvertReasoningField(&choice.Message)
|
||||
}
|
||||
}
|
||||
@@ -81,7 +81,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
if info.IsStream {
|
||||
var responseText string
|
||||
err, responseText = palmStreamHandler(c, resp)
|
||||
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
} else {
|
||||
usage, err = palmHandler(c, info, resp)
|
||||
}
|
||||
|
||||
@@ -130,7 +130,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
|
||||
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
|
||||
return service.ResponseText2Usage(c, responseText, info.UpstreamModelName, info.PromptTokens), nil
|
||||
}
|
||||
|
||||
func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||||
|
||||
@@ -76,7 +76,9 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "claude") {
|
||||
a.RequestMode = RequestModeClaude
|
||||
} else if strings.Contains(info.UpstreamModelName, "llama") {
|
||||
} else if strings.Contains(info.UpstreamModelName, "llama") ||
|
||||
// open source models
|
||||
strings.Contains(info.UpstreamModelName, "-maas") {
|
||||
a.RequestMode = RequestModeLlama
|
||||
} else {
|
||||
a.RequestMode = RequestModeGemini
|
||||
@@ -220,6 +222,9 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
if a.AccountCredentials.ProjectID != "" {
|
||||
req.Set("x-goog-user-project", a.AccountCredentials.ProjectID)
|
||||
}
|
||||
if strings.Contains(info.UpstreamModelName, "claude") {
|
||||
claude.CommonClaudeHeadersOperation(c, req, info)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -291,7 +296,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
info.UpstreamModelName = claudeReq.Model
|
||||
return vertexClaudeReq, nil
|
||||
} else if a.RequestMode == RequestModeGemini {
|
||||
geminiRequest, err := gemini.CovertGemini2OpenAI(c, *request, info)
|
||||
geminiRequest, err := gemini.CovertOpenAI2Gemini(c, *request, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -23,8 +23,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
contextKeyTTSRequest = "volcengine_tts_request"
|
||||
contextKeyResponseFormat = "response_format"
|
||||
contextKeyTTSRequest = "volcengine_tts_request"
|
||||
contextKeyResponseFormat = "response_format"
|
||||
DoubaoCodingPlan = "doubao-coding-plan"
|
||||
DoubaoCodingPlanClaudeBaseURL = "https://ark.cn-beijing.volces.com/api/coding"
|
||||
DoubaoCodingPlanOpenAIBaseURL = "https://ark.cn-beijing.volces.com/api/coding/v3"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
@@ -238,6 +241,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
|
||||
switch info.RelayFormat {
|
||||
case types.RelayFormatClaude:
|
||||
if baseUrl == DoubaoCodingPlan {
|
||||
return fmt.Sprintf("%s/v1/messages", DoubaoCodingPlanClaudeBaseURL), nil
|
||||
}
|
||||
if strings.HasPrefix(info.UpstreamModelName, "bot") {
|
||||
return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil
|
||||
}
|
||||
@@ -245,6 +251,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
default:
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
if baseUrl == DoubaoCodingPlan {
|
||||
return fmt.Sprintf("%s/chat/completions", DoubaoCodingPlanOpenAIBaseURL), nil
|
||||
}
|
||||
if strings.HasPrefix(info.UpstreamModelName, "bot") {
|
||||
return fmt.Sprintf("%s/api/v3/bots/chat/completions", baseUrl), nil
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
|
||||
})
|
||||
|
||||
if !containStreamUsage {
|
||||
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||
usage.CompletionTokens += toolCount * 7
|
||||
}
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -30,7 +30,7 @@ type ParamOperation struct {
|
||||
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
|
||||
}
|
||||
|
||||
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
||||
func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
|
||||
if len(paramOverride) == 0 {
|
||||
return jsonData, nil
|
||||
}
|
||||
@@ -38,7 +38,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) (
|
||||
// 尝试断言为操作格式
|
||||
if operations, ok := tryParseOperations(paramOverride); ok {
|
||||
// 使用新方法
|
||||
result, err := applyOperations(string(jsonData), operations)
|
||||
result, err := applyOperations(string(jsonData), operations, conditionContext)
|
||||
return []byte(result), err
|
||||
}
|
||||
|
||||
@@ -123,13 +123,13 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) {
|
||||
func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
|
||||
if len(conditions) == 0 {
|
||||
return true, nil // 没有条件,直接通过
|
||||
}
|
||||
results := make([]bool, len(conditions))
|
||||
for i, condition := range conditions {
|
||||
result, err := checkSingleCondition(jsonStr, condition)
|
||||
result, err := checkSingleCondition(jsonStr, contextJSON, condition)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -153,10 +153,13 @@ func checkConditions(jsonStr string, conditions []ConditionOperation, logic stri
|
||||
}
|
||||
}
|
||||
|
||||
func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
|
||||
func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
|
||||
// 处理负数索引
|
||||
path := processNegativeIndex(jsonStr, condition.Path)
|
||||
value := gjson.Get(jsonStr, path)
|
||||
if !value.Exists() && contextJSON != "" {
|
||||
value = gjson.Get(contextJSON, condition.Path)
|
||||
}
|
||||
if !value.Exists() {
|
||||
if condition.PassMissingKey {
|
||||
return true, nil
|
||||
@@ -165,7 +168,7 @@ func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, e
|
||||
}
|
||||
|
||||
// 利用gjson的类型解析
|
||||
targetBytes, err := json.Marshal(condition.Value)
|
||||
targetBytes, err := common.Marshal(condition.Value)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to marshal condition value: %v", err)
|
||||
}
|
||||
@@ -292,7 +295,7 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool,
|
||||
// applyOperationsLegacy 原参数覆盖方法
|
||||
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
||||
reqMap := make(map[string]interface{})
|
||||
err := json.Unmarshal(jsonData, &reqMap)
|
||||
err := common.Unmarshal(jsonData, &reqMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -301,14 +304,23 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}
|
||||
reqMap[key] = value
|
||||
}
|
||||
|
||||
return json.Marshal(reqMap)
|
||||
return common.Marshal(reqMap)
|
||||
}
|
||||
|
||||
func applyOperations(jsonStr string, operations []ParamOperation) (string, error) {
|
||||
func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
|
||||
var contextJSON string
|
||||
if conditionContext != nil && len(conditionContext) > 0 {
|
||||
ctxBytes, err := common.Marshal(conditionContext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal condition context: %v", err)
|
||||
}
|
||||
contextJSON = string(ctxBytes)
|
||||
}
|
||||
|
||||
result := jsonStr
|
||||
for _, op := range operations {
|
||||
// 检查条件是否满足
|
||||
ok, err := checkConditions(result, op.Conditions, op.Logic)
|
||||
ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -414,7 +426,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
||||
var currentMap, newMap map[string]interface{}
|
||||
|
||||
// 解析当前值
|
||||
if err := json.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
|
||||
if err := common.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// 解析新值
|
||||
@@ -422,8 +434,8 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
||||
case map[string]interface{}:
|
||||
newMap = v
|
||||
default:
|
||||
jsonBytes, _ := json.Marshal(v)
|
||||
if err := json.Unmarshal(jsonBytes, &newMap); err != nil {
|
||||
jsonBytes, _ := common.Marshal(v)
|
||||
if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
@@ -439,3 +451,31 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
||||
}
|
||||
return sjson.Set(jsonStr, path, result)
|
||||
}
|
||||
|
||||
// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
|
||||
// 目前内置以下字段:
|
||||
// - model:优先使用上游模型名(UpstreamModelName),若不存在则回落到原始模型名(OriginModelName)。
|
||||
// - upstream_model:始终为通道映射后的上游模型名。
|
||||
// - original_model:请求最初指定的模型名。
|
||||
func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
|
||||
if info == nil || info.ChannelMeta == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := make(map[string]interface{})
|
||||
if info.UpstreamModelName != "" {
|
||||
ctx["model"] = info.UpstreamModelName
|
||||
ctx["upstream_model"] = info.UpstreamModelName
|
||||
}
|
||||
if info.OriginModelName != "" {
|
||||
ctx["original_model"] = info.OriginModelName
|
||||
if _, exists := ctx["model"]; !exists {
|
||||
ctx["model"] = info.OriginModelName
|
||||
}
|
||||
}
|
||||
|
||||
if len(ctx) == 0 {
|
||||
return nil
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -49,6 +49,14 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
|
||||
logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData)))
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
|
||||
@@ -156,7 +156,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -69,7 +69,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
|
||||
// apply param override
|
||||
if len(info.ParamOverride) > 0 {
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
|
||||
jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
|
||||
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
|
||||
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
|
||||
apiRouter.GET("/oauth/discord", middleware.CriticalRateLimit(), controller.DiscordOAuth)
|
||||
apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), controller.OidcAuth)
|
||||
apiRouter.GET("/oauth/linuxdo", middleware.CriticalRateLimit(), controller.LinuxdoOAuth)
|
||||
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
// return image.Config, format, clean base64 string, error
|
||||
func DecodeBase64ImageData(base64String string) (image.Config, string, string, error) {
|
||||
// 去除base64数据的URL前缀(如果有)
|
||||
if idx := strings.Index(base64String, ","); idx != -1 {
|
||||
|
||||
@@ -62,6 +62,12 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m
|
||||
adminInfo["is_multi_key"] = true
|
||||
adminInfo["multi_key_index"] = common.GetContextKeyInt(ctx, constant.ContextKeyChannelMultiKeyIndex)
|
||||
}
|
||||
|
||||
isLocalCountTokens := common.GetContextKeyBool(ctx, constant.ContextKeyLocalCountTokens)
|
||||
if isLocalCountTokens {
|
||||
adminInfo["local_count_tokens"] = isLocalCountTokens
|
||||
}
|
||||
|
||||
other["admin_info"] = adminInfo
|
||||
appendRequestPath(ctx, relayInfo, other)
|
||||
return other
|
||||
|
||||
@@ -143,6 +143,12 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
if fileMeta.Detail == "low" && !isPatchBased {
|
||||
return baseTokens, nil
|
||||
}
|
||||
|
||||
// Whether to count image tokens at all
|
||||
if !constant.GetMediaToken {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
|
||||
if !constant.GetMediaTokenNotStream && !stream {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
@@ -150,10 +156,6 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
|
||||
fileMeta.Detail = "high"
|
||||
}
|
||||
// Whether to count image tokens at all
|
||||
if !constant.GetMediaToken {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
|
||||
// Decode image to get dimensions
|
||||
var config image.Config
|
||||
@@ -256,16 +258,15 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er
|
||||
}
|
||||
|
||||
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
||||
// 是否统计token
|
||||
if !constant.CountToken {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
if meta == nil {
|
||||
return 0, errors.New("token count meta is nil")
|
||||
}
|
||||
|
||||
if !constant.GetMediaToken {
|
||||
return 0, nil
|
||||
}
|
||||
if !constant.GetMediaTokenNotStream && !info.IsStream {
|
||||
return 0, nil
|
||||
}
|
||||
if info.RelayFormat == types.RelayFormatOpenAIRealtime {
|
||||
return 0, nil
|
||||
}
|
||||
@@ -316,9 +317,19 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
shouldFetchFiles = false
|
||||
}
|
||||
|
||||
if shouldFetchFiles {
|
||||
for _, file := range meta.Files {
|
||||
if strings.HasPrefix(file.OriginData, "http") {
|
||||
// 是否本地计算媒体token数量
|
||||
if !constant.GetMediaToken {
|
||||
shouldFetchFiles = false
|
||||
}
|
||||
|
||||
// 是否在非流模式下本地计算媒体token数量
|
||||
if !constant.GetMediaTokenNotStream && !info.IsStream {
|
||||
shouldFetchFiles = false
|
||||
}
|
||||
|
||||
for _, file := range meta.Files {
|
||||
if strings.HasPrefix(file.OriginData, "http") {
|
||||
if shouldFetchFiles {
|
||||
mineType, err := GetFileTypeFromUrl(c, file.OriginData, "token_counter")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting file base64 from url: %v", err)
|
||||
@@ -333,28 +344,28 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
file.FileType = types.FileTypeFile
|
||||
}
|
||||
file.MimeType = mineType
|
||||
} else if strings.HasPrefix(file.OriginData, "data:") {
|
||||
// get mime type from base64 header
|
||||
parts := strings.SplitN(file.OriginData, ",", 2)
|
||||
if len(parts) >= 1 {
|
||||
header := parts[0]
|
||||
// Extract mime type from "data:mime/type;base64" format
|
||||
if strings.Contains(header, ":") && strings.Contains(header, ";") {
|
||||
mimeStart := strings.Index(header, ":") + 1
|
||||
mimeEnd := strings.Index(header, ";")
|
||||
if mimeStart < mimeEnd {
|
||||
mineType := header[mimeStart:mimeEnd]
|
||||
if strings.HasPrefix(mineType, "image/") {
|
||||
file.FileType = types.FileTypeImage
|
||||
} else if strings.HasPrefix(mineType, "video/") {
|
||||
file.FileType = types.FileTypeVideo
|
||||
} else if strings.HasPrefix(mineType, "audio/") {
|
||||
file.FileType = types.FileTypeAudio
|
||||
} else {
|
||||
file.FileType = types.FileTypeFile
|
||||
}
|
||||
file.MimeType = mineType
|
||||
}
|
||||
} else if strings.HasPrefix(file.OriginData, "data:") {
|
||||
// get mime type from base64 header
|
||||
parts := strings.SplitN(file.OriginData, ",", 2)
|
||||
if len(parts) >= 1 {
|
||||
header := parts[0]
|
||||
// Extract mime type from "data:mime/type;base64" format
|
||||
if strings.Contains(header, ":") && strings.Contains(header, ";") {
|
||||
mimeStart := strings.Index(header, ":") + 1
|
||||
mimeEnd := strings.Index(header, ";")
|
||||
if mimeStart < mimeEnd {
|
||||
mineType := header[mimeStart:mimeEnd]
|
||||
if strings.HasPrefix(mineType, "image/") {
|
||||
file.FileType = types.FileTypeImage
|
||||
} else if strings.HasPrefix(mineType, "video/") {
|
||||
file.FileType = types.FileTypeVideo
|
||||
} else if strings.HasPrefix(mineType, "audio/") {
|
||||
file.FileType = types.FileTypeAudio
|
||||
} else {
|
||||
file.FileType = types.FileTypeFile
|
||||
}
|
||||
file.MimeType = mineType
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -365,7 +376,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
switch file.FileType {
|
||||
case types.FileTypeImage:
|
||||
if info.RelayFormat == types.RelayFormatGemini {
|
||||
tkm += 256
|
||||
tkm += 520 // gemini per input image tokens
|
||||
} else {
|
||||
token, err := getImageToken(file, model, info.IsStream)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
//func GetPromptTokens(textRequest dto.GeneralOpenAIRequest, relayMode int) (int, error) {
|
||||
@@ -16,7 +19,8 @@ import (
|
||||
// return 0, errors.New("unknown relay mode")
|
||||
//}
|
||||
|
||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
|
||||
func ResponseText2Usage(c *gin.Context, responseText string, modeName string, promptTokens int) *dto.Usage {
|
||||
common.SetContextKey(c, constant.ContextKeyLocalCountTokens, true)
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
ctkm := CountTextToken(responseText, modeName)
|
||||
|
||||
@@ -11,13 +11,13 @@ type GeminiSettings struct {
|
||||
SupportedImagineModels []string `json:"supported_imagine_models"`
|
||||
ThinkingAdapterEnabled bool `json:"thinking_adapter_enabled"`
|
||||
ThinkingAdapterBudgetTokensPercentage float64 `json:"thinking_adapter_budget_tokens_percentage"`
|
||||
FunctionCallThoughtSignatureEnabled bool `json:"function_call_thought_signature_enabled"`
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var defaultGeminiSettings = GeminiSettings{
|
||||
SafetySettings: map[string]string{
|
||||
"default": "OFF",
|
||||
"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
|
||||
"default": "OFF",
|
||||
},
|
||||
VersionSettings: map[string]string{
|
||||
"default": "v1beta",
|
||||
@@ -29,6 +29,7 @@ var defaultGeminiSettings = GeminiSettings{
|
||||
},
|
||||
ThinkingAdapterEnabled: false,
|
||||
ThinkingAdapterBudgetTokensPercentage: 0.6,
|
||||
FunctionCallThoughtSignatureEnabled: true,
|
||||
}
|
||||
|
||||
// 全局实例
|
||||
|
||||
@@ -598,6 +598,11 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
|
||||
return 2.5 / 0.3, false
|
||||
} else if strings.HasPrefix(name, "gemini-robotics-er-1.5") {
|
||||
return 2.5 / 0.3, false
|
||||
} else if strings.HasPrefix(name, "gemini-3-pro") {
|
||||
if strings.HasPrefix(name, "gemini-3-pro-image") {
|
||||
return 60, false
|
||||
}
|
||||
return 6, false
|
||||
}
|
||||
return 4, false
|
||||
}
|
||||
|
||||
21
setting/system_setting/discord.go
Normal file
21
setting/system_setting/discord.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package system_setting
|
||||
|
||||
import "github.com/QuantumNous/new-api/setting/config"
|
||||
|
||||
type DiscordSettings struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
ClientId string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var defaultDiscordSettings = DiscordSettings{}
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器
|
||||
config.GlobalConfig.Register("discord", &defaultDiscordSettings)
|
||||
}
|
||||
|
||||
func GetDiscordSettings() *DiscordSettings {
|
||||
return &defaultDiscordSettings
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
{
|
||||
"lockfileVersion": 1,
|
||||
"configVersion": 0,
|
||||
"workspaces": {
|
||||
"": {
|
||||
"name": "react-template",
|
||||
|
||||
@@ -192,6 +192,14 @@ function App() {
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/discord'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>} key={location.pathname}>
|
||||
<OAuth2Callback type='discord'></OAuth2Callback>
|
||||
</Suspense>
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
path='/oauth/oidc'
|
||||
element={
|
||||
|
||||
@@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import React, { useContext, useEffect, useRef, useState } from 'react';
|
||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
||||
import { UserContext } from '../../context/User';
|
||||
import {
|
||||
@@ -30,6 +30,7 @@ import {
|
||||
getSystemName,
|
||||
setUserData,
|
||||
onGitHubOAuthClicked,
|
||||
onDiscordOAuthClicked,
|
||||
onOIDCClicked,
|
||||
onLinuxDOOAuthClicked,
|
||||
prepareCredentialRequestOptions,
|
||||
@@ -53,6 +54,7 @@ import WeChatIcon from '../common/logo/WeChatIcon';
|
||||
import LinuxDoIcon from '../common/logo/LinuxDoIcon';
|
||||
import TwoFAVerification from './TwoFAVerification';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { SiDiscord }from 'react-icons/si';
|
||||
|
||||
const LoginForm = () => {
|
||||
let navigate = useNavigate();
|
||||
@@ -73,6 +75,7 @@ const LoginForm = () => {
|
||||
const [showEmailLogin, setShowEmailLogin] = useState(false);
|
||||
const [wechatLoading, setWechatLoading] = useState(false);
|
||||
const [githubLoading, setGithubLoading] = useState(false);
|
||||
const [discordLoading, setDiscordLoading] = useState(false);
|
||||
const [oidcLoading, setOidcLoading] = useState(false);
|
||||
const [linuxdoLoading, setLinuxdoLoading] = useState(false);
|
||||
const [emailLoginLoading, setEmailLoginLoading] = useState(false);
|
||||
@@ -87,6 +90,9 @@ const LoginForm = () => {
|
||||
const [agreedToTerms, setAgreedToTerms] = useState(false);
|
||||
const [hasUserAgreement, setHasUserAgreement] = useState(false);
|
||||
const [hasPrivacyPolicy, setHasPrivacyPolicy] = useState(false);
|
||||
const [githubButtonText, setGithubButtonText] = useState('使用 GitHub 继续');
|
||||
const [githubButtonDisabled, setGithubButtonDisabled] = useState(false);
|
||||
const githubTimeoutRef = useRef(null);
|
||||
|
||||
const logo = getLogo();
|
||||
const systemName = getSystemName();
|
||||
@@ -116,6 +122,12 @@ const LoginForm = () => {
|
||||
isPasskeySupported()
|
||||
.then(setPasskeySupported)
|
||||
.catch(() => setPasskeySupported(false));
|
||||
|
||||
return () => {
|
||||
if (githubTimeoutRef.current) {
|
||||
clearTimeout(githubTimeoutRef.current);
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -267,7 +279,20 @@ const LoginForm = () => {
|
||||
showInfo(t('请先阅读并同意用户协议和隐私政策'));
|
||||
return;
|
||||
}
|
||||
if (githubButtonDisabled) {
|
||||
return;
|
||||
}
|
||||
setGithubLoading(true);
|
||||
setGithubButtonDisabled(true);
|
||||
setGithubButtonText(t('正在跳转 GitHub...'));
|
||||
if (githubTimeoutRef.current) {
|
||||
clearTimeout(githubTimeoutRef.current);
|
||||
}
|
||||
githubTimeoutRef.current = setTimeout(() => {
|
||||
setGithubLoading(false);
|
||||
setGithubButtonText(t('请求超时,请刷新页面后重新发起 GitHub 登录'));
|
||||
setGithubButtonDisabled(true);
|
||||
}, 20000);
|
||||
try {
|
||||
onGitHubOAuthClicked(status.github_client_id);
|
||||
} finally {
|
||||
@@ -276,6 +301,21 @@ const LoginForm = () => {
|
||||
}
|
||||
};
|
||||
|
||||
// 包装的Discord登录点击处理
|
||||
const handleDiscordClick = () => {
|
||||
if ((hasUserAgreement || hasPrivacyPolicy) && !agreedToTerms) {
|
||||
showInfo(t('请先阅读并同意用户协议和隐私政策'));
|
||||
return;
|
||||
}
|
||||
setDiscordLoading(true);
|
||||
try {
|
||||
onDiscordOAuthClicked(status.discord_client_id);
|
||||
} finally {
|
||||
// 由于重定向,这里不会执行到,但为了完整性添加
|
||||
setTimeout(() => setDiscordLoading(false), 3000);
|
||||
}
|
||||
};
|
||||
|
||||
// 包装的OIDC登录点击处理
|
||||
const handleOIDCClick = () => {
|
||||
if ((hasUserAgreement || hasPrivacyPolicy) && !agreedToTerms) {
|
||||
@@ -444,8 +484,22 @@ const LoginForm = () => {
|
||||
icon={<IconGithubLogo size='large' />}
|
||||
onClick={handleGitHubClick}
|
||||
loading={githubLoading}
|
||||
disabled={githubButtonDisabled}
|
||||
>
|
||||
<span className='ml-3'>{t('使用 GitHub 继续')}</span>
|
||||
<span className='ml-3'>{githubButtonText}</span>
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{status.discord_oauth && (
|
||||
<Button
|
||||
theme='outline'
|
||||
className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
|
||||
type='tertiary'
|
||||
icon={<SiDiscord style={{ color: '#5865F2', width: '20px', height: '20px' }} />}
|
||||
onClick={handleDiscordClick}
|
||||
loading={discordLoading}
|
||||
>
|
||||
<span className='ml-3'>{t('使用 Discord 继续')}</span>
|
||||
</Button>
|
||||
)}
|
||||
|
||||
@@ -691,6 +745,7 @@ const LoginForm = () => {
|
||||
</Form>
|
||||
|
||||
{(status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
@@ -826,6 +881,7 @@ const LoginForm = () => {
|
||||
{showEmailLogin ||
|
||||
!(
|
||||
status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
|
||||
@@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useContext, useEffect, useState } from 'react';
|
||||
import React, { useContext, useEffect, useRef, useState } from 'react';
|
||||
import { Link, useNavigate } from 'react-router-dom';
|
||||
import {
|
||||
API,
|
||||
@@ -28,6 +28,7 @@ import {
|
||||
updateAPI,
|
||||
getSystemName,
|
||||
setUserData,
|
||||
onDiscordOAuthClicked,
|
||||
} from '../../helpers';
|
||||
import Turnstile from 'react-turnstile';
|
||||
import { Button, Card, Checkbox, Divider, Form, Icon, Modal } from '@douyinfe/semi-ui';
|
||||
@@ -51,6 +52,7 @@ import WeChatIcon from '../common/logo/WeChatIcon';
|
||||
import TelegramLoginButton from 'react-telegram-login/src';
|
||||
import { UserContext } from '../../context/User';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { SiDiscord } from 'react-icons/si';
|
||||
|
||||
const RegisterForm = () => {
|
||||
let navigate = useNavigate();
|
||||
@@ -72,6 +74,7 @@ const RegisterForm = () => {
|
||||
const [showEmailRegister, setShowEmailRegister] = useState(false);
|
||||
const [wechatLoading, setWechatLoading] = useState(false);
|
||||
const [githubLoading, setGithubLoading] = useState(false);
|
||||
const [discordLoading, setDiscordLoading] = useState(false);
|
||||
const [oidcLoading, setOidcLoading] = useState(false);
|
||||
const [linuxdoLoading, setLinuxdoLoading] = useState(false);
|
||||
const [emailRegisterLoading, setEmailRegisterLoading] = useState(false);
|
||||
@@ -85,6 +88,9 @@ const RegisterForm = () => {
|
||||
const [agreedToTerms, setAgreedToTerms] = useState(false);
|
||||
const [hasUserAgreement, setHasUserAgreement] = useState(false);
|
||||
const [hasPrivacyPolicy, setHasPrivacyPolicy] = useState(false);
|
||||
const [githubButtonText, setGithubButtonText] = useState('使用 GitHub 继续');
|
||||
const [githubButtonDisabled, setGithubButtonDisabled] = useState(false);
|
||||
const githubTimeoutRef = useRef(null);
|
||||
|
||||
const logo = getLogo();
|
||||
const systemName = getSystemName();
|
||||
@@ -128,6 +134,14 @@ const RegisterForm = () => {
|
||||
return () => clearInterval(countdownInterval); // Clean up on unmount
|
||||
}, [disableButton, countdown]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (githubTimeoutRef.current) {
|
||||
clearTimeout(githubTimeoutRef.current);
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
const onWeChatLoginClicked = () => {
|
||||
setWechatLoading(true);
|
||||
setShowWeChatLoginModal(true);
|
||||
@@ -232,7 +246,20 @@ const RegisterForm = () => {
|
||||
};
|
||||
|
||||
const handleGitHubClick = () => {
|
||||
if (githubButtonDisabled) {
|
||||
return;
|
||||
}
|
||||
setGithubLoading(true);
|
||||
setGithubButtonDisabled(true);
|
||||
setGithubButtonText(t('正在跳转 GitHub...'));
|
||||
if (githubTimeoutRef.current) {
|
||||
clearTimeout(githubTimeoutRef.current);
|
||||
}
|
||||
githubTimeoutRef.current = setTimeout(() => {
|
||||
setGithubLoading(false);
|
||||
setGithubButtonText(t('请求超时,请刷新页面后重新发起 GitHub 登录'));
|
||||
setGithubButtonDisabled(true);
|
||||
}, 20000);
|
||||
try {
|
||||
onGitHubOAuthClicked(status.github_client_id);
|
||||
} finally {
|
||||
@@ -240,6 +267,15 @@ const RegisterForm = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const handleDiscordClick = () => {
|
||||
setDiscordLoading(true);
|
||||
try {
|
||||
onDiscordOAuthClicked(status.discord_client_id);
|
||||
} finally {
|
||||
setTimeout(() => setDiscordLoading(false), 3000);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOIDCClick = () => {
|
||||
setOidcLoading(true);
|
||||
try {
|
||||
@@ -347,8 +383,22 @@ const RegisterForm = () => {
|
||||
icon={<IconGithubLogo size='large' />}
|
||||
onClick={handleGitHubClick}
|
||||
loading={githubLoading}
|
||||
disabled={githubButtonDisabled}
|
||||
>
|
||||
<span className='ml-3'>{t('使用 GitHub 继续')}</span>
|
||||
<span className='ml-3'>{githubButtonText}</span>
|
||||
</Button>
|
||||
)}
|
||||
|
||||
{status.discord_oauth && (
|
||||
<Button
|
||||
theme='outline'
|
||||
className='w-full h-12 flex items-center justify-center !rounded-full border border-gray-200 hover:bg-gray-50 transition-colors'
|
||||
type='tertiary'
|
||||
icon={<SiDiscord style={{ color: '#5865F2', width: '20px', height: '20px' }} />}
|
||||
onClick={handleDiscordClick}
|
||||
loading={discordLoading}
|
||||
>
|
||||
<span className='ml-3'>{t('使用 Discord 继续')}</span>
|
||||
</Button>
|
||||
)}
|
||||
|
||||
@@ -566,6 +616,7 @@ const RegisterForm = () => {
|
||||
</Form>
|
||||
|
||||
{(status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
@@ -661,6 +712,7 @@ const RegisterForm = () => {
|
||||
{showEmailRegister ||
|
||||
!(
|
||||
status.github_oauth ||
|
||||
status.discord_oauth ||
|
||||
status.oidc_enabled ||
|
||||
status.wechat_login ||
|
||||
status.linuxdo_oauth ||
|
||||
|
||||
@@ -20,7 +20,7 @@ For commercial licensing, please contact support@quantumnous.com
|
||||
import React from 'react';
|
||||
import { Button, Dropdown } from '@douyinfe/semi-ui';
|
||||
import { Languages } from 'lucide-react';
|
||||
import { CN, GB, FR, RU, JP } from 'country-flag-icons/react/3x2';
|
||||
import { CN, GB, FR, RU, JP, VN } from 'country-flag-icons/react/3x2';
|
||||
|
||||
const LanguageSelector = ({ currentLang, onLanguageChange, t }) => {
|
||||
return (
|
||||
@@ -65,6 +65,13 @@ const LanguageSelector = ({ currentLang, onLanguageChange, t }) => {
|
||||
<RU title='Русский' className='!w-5 !h-auto' />
|
||||
<span>Русский</span>
|
||||
</Dropdown.Item>
|
||||
<Dropdown.Item
|
||||
onClick={() => onLanguageChange('vi')}
|
||||
className={`!flex !items-center !gap-2 !px-3 !py-1.5 !text-sm !text-semi-color-text-0 dark:!text-gray-200 ${currentLang === 'vi' ? '!bg-semi-color-primary-light-default dark:!bg-blue-600 !font-semibold' : 'hover:!bg-semi-color-fill-1 dark:hover:!bg-gray-600'}`}
|
||||
>
|
||||
<VN title='Tiếng Việt' className='!w-5 !h-auto' />
|
||||
<span>Tiếng Việt</span>
|
||||
</Dropdown.Item>
|
||||
</Dropdown.Menu>
|
||||
}
|
||||
>
|
||||
|
||||
@@ -52,6 +52,9 @@ const SystemSetting = () => {
|
||||
GitHubOAuthEnabled: '',
|
||||
GitHubClientId: '',
|
||||
GitHubClientSecret: '',
|
||||
'discord.enabled': '',
|
||||
'discord.client_id': '',
|
||||
'discord.client_secret': '',
|
||||
'oidc.enabled': '',
|
||||
'oidc.client_id': '',
|
||||
'oidc.client_secret': '',
|
||||
@@ -179,6 +182,7 @@ const SystemSetting = () => {
|
||||
case 'EmailAliasRestrictionEnabled':
|
||||
case 'SMTPSSLEnabled':
|
||||
case 'LinuxDOOAuthEnabled':
|
||||
case 'discord.enabled':
|
||||
case 'oidc.enabled':
|
||||
case 'passkey.enabled':
|
||||
case 'passkey.allow_insecure_origin':
|
||||
@@ -473,6 +477,27 @@ const SystemSetting = () => {
|
||||
}
|
||||
};
|
||||
|
||||
const submitDiscordOAuth = async () => {
|
||||
const options = [];
|
||||
|
||||
if (originInputs['discord.client_id'] !== inputs['discord.client_id']) {
|
||||
options.push({ key: 'discord.client_id', value: inputs['discord.client_id'] });
|
||||
}
|
||||
if (
|
||||
originInputs['discord.client_secret'] !== inputs['discord.client_secret'] &&
|
||||
inputs['discord.client_secret'] !== ''
|
||||
) {
|
||||
options.push({
|
||||
key: 'discord.client_secret',
|
||||
value: inputs['discord.client_secret'],
|
||||
});
|
||||
}
|
||||
|
||||
if (options.length > 0) {
|
||||
await updateOptions(options);
|
||||
}
|
||||
};
|
||||
|
||||
const submitOIDCSettings = async () => {
|
||||
if (inputs['oidc.well_known'] && inputs['oidc.well_known'] !== '') {
|
||||
if (
|
||||
@@ -1014,6 +1039,15 @@ const SystemSetting = () => {
|
||||
>
|
||||
{t('允许通过 GitHub 账户登录 & 注册')}
|
||||
</Form.Checkbox>
|
||||
<Form.Checkbox
|
||||
field='discord.enabled'
|
||||
noLabel
|
||||
onChange={(e) =>
|
||||
handleCheckboxChange('discord.enabled', e)
|
||||
}
|
||||
>
|
||||
{t('允许通过 Discord 账户登录 & 注册')}
|
||||
</Form.Checkbox>
|
||||
<Form.Checkbox
|
||||
field='LinuxDOOAuthEnabled'
|
||||
noLabel
|
||||
@@ -1410,6 +1444,37 @@ const SystemSetting = () => {
|
||||
</Button>
|
||||
</Form.Section>
|
||||
</Card>
|
||||
<Card>
|
||||
<Form.Section text={t('配置 Discord OAuth')}>
|
||||
<Text>{t('用以支持通过 Discord 进行登录注册')}</Text>
|
||||
<Banner
|
||||
type='info'
|
||||
description={`${t('Homepage URL 填')} ${inputs.ServerAddress ? inputs.ServerAddress : t('网站地址')},${t('Authorization callback URL 填')} ${inputs.ServerAddress ? inputs.ServerAddress : t('网站地址')}/oauth/discord`}
|
||||
style={{ marginBottom: 20, marginTop: 16 }}
|
||||
/>
|
||||
<Row
|
||||
gutter={{ xs: 8, sm: 16, md: 24, lg: 24, xl: 24, xxl: 24 }}
|
||||
>
|
||||
<Col xs={24} sm={24} md={12} lg={12} xl={12}>
|
||||
<Form.Input
|
||||
field="['discord.client_id']"
|
||||
label={t('Discord Client ID')}
|
||||
/>
|
||||
</Col>
|
||||
<Col xs={24} sm={24} md={12} lg={12} xl={12}>
|
||||
<Form.Input
|
||||
field="['discord.client_secret']"
|
||||
label={t('Discord Client Secret')}
|
||||
type='password'
|
||||
placeholder={t('敏感信息不会发送到前端显示')}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Button onClick={submitDiscordOAuth}>
|
||||
{t('保存 Discord OAuth 设置')}
|
||||
</Button>
|
||||
</Form.Section>
|
||||
</Card>
|
||||
<Card>
|
||||
<Form.Section text={t('配置 Linux DO OAuth')}>
|
||||
<Text>
|
||||
|
||||
@@ -38,13 +38,14 @@ import {
|
||||
IconLock,
|
||||
IconDelete,
|
||||
} from '@douyinfe/semi-icons';
|
||||
import { SiTelegram, SiWechat, SiLinux } from 'react-icons/si';
|
||||
import { SiTelegram, SiWechat, SiLinux, SiDiscord } from 'react-icons/si';
|
||||
import { UserPlus, ShieldCheck } from 'lucide-react';
|
||||
import TelegramLoginButton from 'react-telegram-login';
|
||||
import {
|
||||
onGitHubOAuthClicked,
|
||||
onOIDCClicked,
|
||||
onLinuxDOOAuthClicked,
|
||||
onDiscordOAuthClicked,
|
||||
} from '../../../../helpers';
|
||||
import TwoFASetting from '../components/TwoFASetting';
|
||||
|
||||
@@ -247,6 +248,47 @@ const AccountManagement = ({
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
{/* Discord绑定 */}
|
||||
<Card className='!rounded-xl'>
|
||||
<div className='flex items-center justify-between gap-3'>
|
||||
<div className='flex items-center flex-1 min-w-0'>
|
||||
<div className='w-10 h-10 rounded-full bg-slate-100 dark:bg-slate-700 flex items-center justify-center mr-3 flex-shrink-0'>
|
||||
<SiDiscord
|
||||
size={20}
|
||||
className='text-slate-600 dark:text-slate-300'
|
||||
/>
|
||||
</div>
|
||||
<div className='flex-1 min-w-0'>
|
||||
<div className='font-medium text-gray-900'>
|
||||
{t('Discord')}
|
||||
</div>
|
||||
<div className='text-sm text-gray-500 truncate'>
|
||||
{renderAccountInfo(
|
||||
userState.user?.discord_id,
|
||||
t('Discord ID'),
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className='flex-shrink-0'>
|
||||
<Button
|
||||
type='primary'
|
||||
theme='outline'
|
||||
size='small'
|
||||
onClick={() =>
|
||||
onDiscordOAuthClicked(status.discord_client_id)
|
||||
}
|
||||
disabled={
|
||||
isBound(userState.user?.discord_id) ||
|
||||
!status.discord_oauth
|
||||
}
|
||||
>
|
||||
{status.discord_oauth ? t('绑定') : t('未启用')}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
{/* OIDC绑定 */}
|
||||
<Card className='!rounded-xl'>
|
||||
<div className='flex items-center justify-between gap-3'>
|
||||
|
||||
@@ -189,6 +189,31 @@ const EditChannelModal = (props) => {
|
||||
const [useManualInput, setUseManualInput] = useState(false); // 是否使用手动输入模式
|
||||
const [keyMode, setKeyMode] = useState('append'); // 密钥模式:replace(覆盖)或 append(追加)
|
||||
const [isEnterpriseAccount, setIsEnterpriseAccount] = useState(false); // 是否为企业账户
|
||||
const [doubaoApiEditUnlocked, setDoubaoApiEditUnlocked] = useState(false); // 豆包渠道自定义 API 地址隐藏入口
|
||||
const redirectModelList = useMemo(() => {
|
||||
const mapping = inputs.model_mapping;
|
||||
if (typeof mapping !== 'string') return [];
|
||||
const trimmed = mapping.trim();
|
||||
if (!trimmed) return [];
|
||||
try {
|
||||
const parsed = JSON.parse(trimmed);
|
||||
if (
|
||||
!parsed ||
|
||||
typeof parsed !== 'object' ||
|
||||
Array.isArray(parsed)
|
||||
) {
|
||||
return [];
|
||||
}
|
||||
const values = Object.values(parsed)
|
||||
.map((value) =>
|
||||
typeof value === 'string' ? value.trim() : undefined,
|
||||
)
|
||||
.filter((value) => value);
|
||||
return Array.from(new Set(values));
|
||||
} catch (error) {
|
||||
return [];
|
||||
}
|
||||
}, [inputs.model_mapping]);
|
||||
|
||||
// 密钥显示状态
|
||||
const [keyDisplayState, setKeyDisplayState] = useState({
|
||||
@@ -218,6 +243,9 @@ const EditChannelModal = (props) => {
|
||||
'channelExtraSettings',
|
||||
];
|
||||
const formContainerRef = useRef(null);
|
||||
const doubaoApiClickCountRef = useRef(0);
|
||||
const initialModelsRef = useRef([]);
|
||||
const initialModelMappingRef = useRef('');
|
||||
|
||||
// 2FA状态更新辅助函数
|
||||
const updateTwoFAState = (updates) => {
|
||||
@@ -306,6 +334,20 @@ const EditChannelModal = (props) => {
|
||||
scrollToSection(availableSections[newIndex]);
|
||||
};
|
||||
|
||||
const handleApiConfigSecretClick = () => {
|
||||
if (inputs.type !== 45) return;
|
||||
const next = doubaoApiClickCountRef.current + 1;
|
||||
doubaoApiClickCountRef.current = next;
|
||||
if (next >= 10) {
|
||||
setDoubaoApiEditUnlocked((unlocked) => {
|
||||
if (!unlocked) {
|
||||
showInfo(t('已解锁豆包自定义 API 地址编辑'));
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// 渠道额外设置状态
|
||||
const [channelSettings, setChannelSettings] = useState({
|
||||
force_format: false,
|
||||
@@ -579,6 +621,10 @@ const EditChannelModal = (props) => {
|
||||
system_prompt: data.system_prompt,
|
||||
system_prompt_override: data.system_prompt_override || false,
|
||||
});
|
||||
initialModelsRef.current = (data.models || [])
|
||||
.map((model) => (model || '').trim())
|
||||
.filter(Boolean);
|
||||
initialModelMappingRef.current = data.model_mapping || '';
|
||||
// console.log(data);
|
||||
} else {
|
||||
showError(message);
|
||||
@@ -724,6 +770,13 @@ const EditChannelModal = (props) => {
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (inputs.type !== 45) {
|
||||
doubaoApiClickCountRef.current = 0;
|
||||
setDoubaoApiEditUnlocked(false);
|
||||
}
|
||||
}, [inputs.type]);
|
||||
|
||||
useEffect(() => {
|
||||
const modelMap = new Map();
|
||||
|
||||
@@ -807,6 +860,13 @@ const EditChannelModal = (props) => {
|
||||
}
|
||||
}, [props.visible, channelId]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isEdit) {
|
||||
initialModelsRef.current = [];
|
||||
initialModelMappingRef.current = '';
|
||||
}
|
||||
}, [isEdit, props.visible]);
|
||||
|
||||
// 统一的模态框重置函数
|
||||
const resetModalState = () => {
|
||||
formApiRef.current?.reset();
|
||||
@@ -823,6 +883,9 @@ const EditChannelModal = (props) => {
|
||||
setKeyMode('append');
|
||||
// 重置企业账户状态
|
||||
setIsEnterpriseAccount(false);
|
||||
// 重置豆包隐藏入口状态
|
||||
setDoubaoApiEditUnlocked(false);
|
||||
doubaoApiClickCountRef.current = 0;
|
||||
// 清空表单中的key_mode字段
|
||||
if (formApiRef.current) {
|
||||
formApiRef.current.setValue('key_mode', undefined);
|
||||
@@ -877,6 +940,80 @@ const EditChannelModal = (props) => {
|
||||
})();
|
||||
};
|
||||
|
||||
const confirmMissingModelMappings = (missingModels) =>
|
||||
new Promise((resolve) => {
|
||||
const modal = Modal.confirm({
|
||||
title: t('模型未加入列表,可能无法调用'),
|
||||
content: (
|
||||
<div className='text-sm leading-6'>
|
||||
<div>
|
||||
{t(
|
||||
'模型重定向里的下列模型尚未添加到“模型”列表,调用时会因为缺少可用模型而失败:',
|
||||
)}
|
||||
</div>
|
||||
<div className='font-mono text-xs break-all text-red-600 mt-1'>
|
||||
{missingModels.join(', ')}
|
||||
</div>
|
||||
<div className='mt-2'>
|
||||
{t(
|
||||
'你可以在“自定义模型名称”处手动添加它们,然后点击填入后再提交,或者直接使用下方操作自动处理。',
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
centered: true,
|
||||
footer: (
|
||||
<Space align='center' className='w-full justify-end'>
|
||||
<Button
|
||||
type='tertiary'
|
||||
onClick={() => {
|
||||
modal.destroy();
|
||||
resolve('cancel');
|
||||
}}
|
||||
>
|
||||
{t('返回修改')}
|
||||
</Button>
|
||||
<Button
|
||||
type='primary'
|
||||
theme='light'
|
||||
onClick={() => {
|
||||
modal.destroy();
|
||||
resolve('submit');
|
||||
}}
|
||||
>
|
||||
{t('直接提交')}
|
||||
</Button>
|
||||
<Button
|
||||
type='primary'
|
||||
theme='solid'
|
||||
onClick={() => {
|
||||
modal.destroy();
|
||||
resolve('add');
|
||||
}}
|
||||
>
|
||||
{t('添加后提交')}
|
||||
</Button>
|
||||
</Space>
|
||||
),
|
||||
});
|
||||
});
|
||||
|
||||
const hasModelConfigChanged = (normalizedModels, modelMappingStr) => {
|
||||
if (!isEdit) return true;
|
||||
const initialModels = initialModelsRef.current;
|
||||
if (normalizedModels.length !== initialModels.length) {
|
||||
return true;
|
||||
}
|
||||
for (let i = 0; i < normalizedModels.length; i++) {
|
||||
if (normalizedModels[i] !== initialModels[i]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
const normalizedMapping = (modelMappingStr || '').trim();
|
||||
const initialMapping = (initialModelMappingRef.current || '').trim();
|
||||
return normalizedMapping !== initialMapping;
|
||||
};
|
||||
|
||||
const submit = async () => {
|
||||
const formValues = formApiRef.current ? formApiRef.current.getValues() : {};
|
||||
let localInputs = { ...formValues };
|
||||
@@ -960,14 +1097,55 @@ const EditChannelModal = (props) => {
|
||||
showInfo(t('请输入API地址!'));
|
||||
return;
|
||||
}
|
||||
if (
|
||||
localInputs.model_mapping &&
|
||||
localInputs.model_mapping !== '' &&
|
||||
!verifyJSON(localInputs.model_mapping)
|
||||
) {
|
||||
showInfo(t('模型映射必须是合法的 JSON 格式!'));
|
||||
return;
|
||||
const hasModelMapping =
|
||||
typeof localInputs.model_mapping === 'string' &&
|
||||
localInputs.model_mapping.trim() !== '';
|
||||
let parsedModelMapping = null;
|
||||
if (hasModelMapping) {
|
||||
if (!verifyJSON(localInputs.model_mapping)) {
|
||||
showInfo(t('模型映射必须是合法的 JSON 格式!'));
|
||||
return;
|
||||
}
|
||||
try {
|
||||
parsedModelMapping = JSON.parse(localInputs.model_mapping);
|
||||
} catch (error) {
|
||||
showInfo(t('模型映射必须是合法的 JSON 格式!'));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const normalizedModels = (localInputs.models || [])
|
||||
.map((model) => (model || '').trim())
|
||||
.filter(Boolean);
|
||||
localInputs.models = normalizedModels;
|
||||
|
||||
if (
|
||||
parsedModelMapping &&
|
||||
typeof parsedModelMapping === 'object' &&
|
||||
!Array.isArray(parsedModelMapping)
|
||||
) {
|
||||
const modelSet = new Set(normalizedModels);
|
||||
const missingModels = Object.keys(parsedModelMapping)
|
||||
.map((key) => (key || '').trim())
|
||||
.filter((key) => key && !modelSet.has(key));
|
||||
const shouldPromptMissing =
|
||||
missingModels.length > 0 &&
|
||||
hasModelConfigChanged(normalizedModels, localInputs.model_mapping);
|
||||
if (shouldPromptMissing) {
|
||||
const confirmAction = await confirmMissingModelMappings(missingModels);
|
||||
if (confirmAction === 'cancel') {
|
||||
return;
|
||||
}
|
||||
if (confirmAction === 'add') {
|
||||
const updatedModels = Array.from(
|
||||
new Set([...normalizedModels, ...missingModels]),
|
||||
);
|
||||
localInputs.models = updatedModels;
|
||||
handleInputChange('models', updatedModels);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (localInputs.base_url && localInputs.base_url.endsWith('/')) {
|
||||
localInputs.base_url = localInputs.base_url.slice(
|
||||
0,
|
||||
@@ -1959,7 +2137,10 @@ const EditChannelModal = (props) => {
|
||||
<div ref={(el) => (formSectionRefs.current.apiConfig = el)}>
|
||||
<Card className='!rounded-2xl shadow-sm border-0 mb-6'>
|
||||
{/* Header: API Config */}
|
||||
<div className='flex items-center mb-2'>
|
||||
<div
|
||||
className='flex items-center mb-2'
|
||||
onClick={handleApiConfigSecretClick}
|
||||
>
|
||||
<Avatar
|
||||
size='small'
|
||||
color='green'
|
||||
@@ -2094,7 +2275,7 @@ const EditChannelModal = (props) => {
|
||||
inputs.type !== 8 &&
|
||||
inputs.type !== 22 &&
|
||||
inputs.type !== 36 &&
|
||||
inputs.type !== 45 && (
|
||||
(inputs.type !== 45 || doubaoApiEditUnlocked) && (
|
||||
<div>
|
||||
<Form.Input
|
||||
field='base_url'
|
||||
@@ -2147,7 +2328,7 @@ const EditChannelModal = (props) => {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{inputs.type === 45 && (
|
||||
{inputs.type === 45 && !doubaoApiEditUnlocked && (
|
||||
<div>
|
||||
<Form.Select
|
||||
field='base_url'
|
||||
@@ -2167,6 +2348,10 @@ const EditChannelModal = (props) => {
|
||||
label:
|
||||
'https://ark.ap-southeast.bytepluses.com',
|
||||
},
|
||||
{
|
||||
value: 'doubao-coding-plan',
|
||||
label: 'Doubao Coding Plan',
|
||||
},
|
||||
]}
|
||||
defaultValue='https://ark.cn-beijing.volces.com'
|
||||
/>
|
||||
@@ -2883,6 +3068,7 @@ const EditChannelModal = (props) => {
|
||||
visible={modelModalVisible}
|
||||
models={fetchedModels}
|
||||
selected={inputs.models}
|
||||
redirectModels={redirectModelList}
|
||||
onConfirm={(selectedModels) => {
|
||||
handleInputChange('models', selectedModels);
|
||||
showSuccess(t('模型列表已更新'));
|
||||
|
||||
@@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import React, { useState, useEffect, useMemo } from 'react';
|
||||
import { useIsMobile } from '../../../../hooks/common/useIsMobile';
|
||||
import {
|
||||
Modal,
|
||||
@@ -28,12 +28,13 @@ import {
|
||||
Empty,
|
||||
Tabs,
|
||||
Collapse,
|
||||
Tooltip,
|
||||
} from '@douyinfe/semi-ui';
|
||||
import {
|
||||
IllustrationNoResult,
|
||||
IllustrationNoResultDark,
|
||||
} from '@douyinfe/semi-illustrations';
|
||||
import { IconSearch } from '@douyinfe/semi-icons';
|
||||
import { IconSearch, IconInfoCircle } from '@douyinfe/semi-icons';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { getModelCategories } from '../../../../helpers/render';
|
||||
|
||||
@@ -41,6 +42,7 @@ const ModelSelectModal = ({
|
||||
visible,
|
||||
models = [],
|
||||
selected = [],
|
||||
redirectModels = [],
|
||||
onConfirm,
|
||||
onCancel,
|
||||
}) => {
|
||||
@@ -50,15 +52,54 @@ const ModelSelectModal = ({
|
||||
const [activeTab, setActiveTab] = useState('new');
|
||||
|
||||
const isMobile = useIsMobile();
|
||||
const normalizeModelName = (model) =>
|
||||
typeof model === 'string' ? model.trim() : '';
|
||||
const normalizedRedirectModels = useMemo(
|
||||
() =>
|
||||
Array.from(
|
||||
new Set(
|
||||
(redirectModels || [])
|
||||
.map((model) => normalizeModelName(model))
|
||||
.filter(Boolean),
|
||||
),
|
||||
),
|
||||
[redirectModels],
|
||||
);
|
||||
const normalizedSelectedSet = useMemo(() => {
|
||||
const set = new Set();
|
||||
(selected || []).forEach((model) => {
|
||||
const normalized = normalizeModelName(model);
|
||||
if (normalized) {
|
||||
set.add(normalized);
|
||||
}
|
||||
});
|
||||
return set;
|
||||
}, [selected]);
|
||||
const classificationSet = useMemo(() => {
|
||||
const set = new Set(normalizedSelectedSet);
|
||||
normalizedRedirectModels.forEach((model) => set.add(model));
|
||||
return set;
|
||||
}, [normalizedSelectedSet, normalizedRedirectModels]);
|
||||
const redirectOnlySet = useMemo(() => {
|
||||
const set = new Set();
|
||||
normalizedRedirectModels.forEach((model) => {
|
||||
if (!normalizedSelectedSet.has(model)) {
|
||||
set.add(model);
|
||||
}
|
||||
});
|
||||
return set;
|
||||
}, [normalizedRedirectModels, normalizedSelectedSet]);
|
||||
|
||||
const filteredModels = models.filter((m) =>
|
||||
m.toLowerCase().includes(keyword.toLowerCase()),
|
||||
String(m || '').toLowerCase().includes(keyword.toLowerCase()),
|
||||
);
|
||||
|
||||
// 分类模型:新获取的模型和已有模型
|
||||
const newModels = filteredModels.filter((model) => !selected.includes(model));
|
||||
const isExistingModel = (model) =>
|
||||
classificationSet.has(normalizeModelName(model));
|
||||
const newModels = filteredModels.filter((model) => !isExistingModel(model));
|
||||
const existingModels = filteredModels.filter((model) =>
|
||||
selected.includes(model),
|
||||
isExistingModel(model),
|
||||
);
|
||||
|
||||
// 同步外部选中值
|
||||
@@ -228,7 +269,20 @@ const ModelSelectModal = ({
|
||||
<div className='grid grid-cols-2 gap-x-4'>
|
||||
{categoryData.models.map((model) => (
|
||||
<Checkbox key={model} value={model} className='my-1'>
|
||||
{model}
|
||||
<span className='flex items-center gap-2'>
|
||||
<span>{model}</span>
|
||||
{redirectOnlySet.has(normalizeModelName(model)) && (
|
||||
<Tooltip
|
||||
position='top'
|
||||
content={t('来自模型重定向,尚未加入模型列表')}
|
||||
>
|
||||
<IconInfoCircle
|
||||
size='small'
|
||||
className='text-amber-500 cursor-help'
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
</span>
|
||||
</Checkbox>
|
||||
))}
|
||||
</div>
|
||||
|
||||
@@ -72,6 +72,7 @@ const EditUserModal = (props) => {
|
||||
password: '',
|
||||
github_id: '',
|
||||
oidc_id: '',
|
||||
discord_id: '',
|
||||
wechat_id: '',
|
||||
telegram_id: '',
|
||||
email: '',
|
||||
@@ -332,6 +333,7 @@ const EditUserModal = (props) => {
|
||||
<Row gutter={12}>
|
||||
{[
|
||||
'github_id',
|
||||
'discord_id',
|
||||
'oidc_id',
|
||||
'wechat_id',
|
||||
'email',
|
||||
|
||||
@@ -231,6 +231,17 @@ export async function getOAuthState() {
|
||||
}
|
||||
}
|
||||
|
||||
export async function onDiscordOAuthClicked(client_id) {
|
||||
const state = await getOAuthState();
|
||||
if (!state) return;
|
||||
const redirect_uri = `${window.location.origin}/oauth/discord`;
|
||||
const response_type = 'code';
|
||||
const scope = 'identify+openid';
|
||||
window.open(
|
||||
`https://discord.com/oauth2/authorize?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`,
|
||||
);
|
||||
}
|
||||
|
||||
export async function onOIDCClicked(auth_url, client_id, openInNewTab = false) {
|
||||
const state = await getOAuthState();
|
||||
if (!state) return;
|
||||
|
||||
@@ -145,8 +145,9 @@ export const getModelCategories = (() => {
|
||||
model.model_name.toLowerCase().includes('gpt') ||
|
||||
model.model_name.toLowerCase().includes('dall-e') ||
|
||||
model.model_name.toLowerCase().includes('whisper') ||
|
||||
model.model_name.toLowerCase().includes('tts') ||
|
||||
model.model_name.toLowerCase().includes('text-') ||
|
||||
model.model_name.toLowerCase().includes('tts-1') ||
|
||||
model.model_name.toLowerCase().includes('text-embedding-3') ||
|
||||
model.model_name.toLowerCase().includes('text-moderation') ||
|
||||
model.model_name.toLowerCase().includes('babbage') ||
|
||||
model.model_name.toLowerCase().includes('davinci') ||
|
||||
model.model_name.toLowerCase().includes('curie') ||
|
||||
@@ -163,19 +164,31 @@ export const getModelCategories = (() => {
|
||||
gemini: {
|
||||
label: 'Gemini',
|
||||
icon: <Gemini.Color />,
|
||||
filter: (model) => model.model_name.toLowerCase().includes('gemini'),
|
||||
filter: (model) =>
|
||||
model.model_name.toLowerCase().includes('gemini') ||
|
||||
model.model_name.toLowerCase().includes('gemma') ||
|
||||
model.model_name.toLowerCase().includes('learnlm') ||
|
||||
model.model_name.toLowerCase().startsWith('embedding-') ||
|
||||
model.model_name.toLowerCase().includes('text-embedding-004') ||
|
||||
model.model_name.toLowerCase().includes('imagen-4') ||
|
||||
model.model_name.toLowerCase().includes('veo-') ||
|
||||
model.model_name.toLowerCase().includes('aqa') ,
|
||||
},
|
||||
moonshot: {
|
||||
label: 'Moonshot',
|
||||
icon: <Moonshot />,
|
||||
filter: (model) => model.model_name.toLowerCase().includes('moonshot'),
|
||||
filter: (model) =>
|
||||
model.model_name.toLowerCase().includes('moonshot') ||
|
||||
model.model_name.toLowerCase().includes('kimi'),
|
||||
},
|
||||
zhipu: {
|
||||
label: t('智谱'),
|
||||
icon: <Zhipu.Color />,
|
||||
filter: (model) =>
|
||||
model.model_name.toLowerCase().includes('chatglm') ||
|
||||
model.model_name.toLowerCase().includes('glm-'),
|
||||
model.model_name.toLowerCase().includes('glm-') ||
|
||||
model.model_name.toLowerCase().includes('cogview') ||
|
||||
model.model_name.toLowerCase().includes('cogvideo'),
|
||||
},
|
||||
qwen: {
|
||||
label: t('通义千问'),
|
||||
@@ -190,7 +203,9 @@ export const getModelCategories = (() => {
|
||||
minimax: {
|
||||
label: 'MiniMax',
|
||||
icon: <Minimax.Color />,
|
||||
filter: (model) => model.model_name.toLowerCase().includes('abab'),
|
||||
filter: (model) =>
|
||||
model.model_name.toLowerCase().includes('abab') ||
|
||||
model.model_name.toLowerCase().includes('minimax'),
|
||||
},
|
||||
baidu: {
|
||||
label: t('文心一言'),
|
||||
@@ -215,7 +230,10 @@ export const getModelCategories = (() => {
|
||||
cohere: {
|
||||
label: 'Cohere',
|
||||
icon: <Cohere.Color />,
|
||||
filter: (model) => model.model_name.toLowerCase().includes('command'),
|
||||
filter: (model) =>
|
||||
model.model_name.toLowerCase().includes('command') ||
|
||||
model.model_name.toLowerCase().includes('c4ai-') ||
|
||||
model.model_name.toLowerCase().includes('embed-'),
|
||||
},
|
||||
cloudflare: {
|
||||
label: 'Cloudflare',
|
||||
@@ -227,11 +245,6 @@ export const getModelCategories = (() => {
|
||||
icon: <Ai360.Color />,
|
||||
filter: (model) => model.model_name.toLowerCase().includes('360'),
|
||||
},
|
||||
yi: {
|
||||
label: t('零一万物'),
|
||||
icon: <Yi.Color />,
|
||||
filter: (model) => model.model_name.toLowerCase().includes('yi'),
|
||||
},
|
||||
jina: {
|
||||
label: 'Jina',
|
||||
icon: <Jina />,
|
||||
@@ -240,7 +253,12 @@ export const getModelCategories = (() => {
|
||||
mistral: {
|
||||
label: 'Mistral AI',
|
||||
icon: <Mistral.Color />,
|
||||
filter: (model) => model.model_name.toLowerCase().includes('mistral'),
|
||||
filter: (model) =>
|
||||
model.model_name.toLowerCase().includes('mistral') ||
|
||||
model.model_name.toLowerCase().includes('codestral') ||
|
||||
model.model_name.toLowerCase().includes('pixtral') ||
|
||||
model.model_name.toLowerCase().includes('voxtral') ||
|
||||
model.model_name.toLowerCase().includes('magistral'),
|
||||
},
|
||||
xai: {
|
||||
label: 'xAI',
|
||||
@@ -257,6 +275,11 @@ export const getModelCategories = (() => {
|
||||
icon: <Doubao.Color />,
|
||||
filter: (model) => model.model_name.toLowerCase().includes('doubao'),
|
||||
},
|
||||
yi: {
|
||||
label: t('零一万物'),
|
||||
icon: <Yi.Color />,
|
||||
filter: (model) => model.model_name.toLowerCase().includes('yi'),
|
||||
},
|
||||
};
|
||||
|
||||
lastLocale = currentLocale;
|
||||
@@ -1772,10 +1795,13 @@ export function renderClaudeModelPrice(
|
||||
|
||||
// Calculate effective input tokens (non-cached + cached with ratio applied + cache creation with ratio applied)
|
||||
const nonCachedTokens = inputTokens;
|
||||
const legacyCacheCreationTokens = hasSplitCacheCreation
|
||||
? 0
|
||||
: cacheCreationTokens;
|
||||
const effectiveInputTokens =
|
||||
nonCachedTokens +
|
||||
cacheTokens * cacheRatio +
|
||||
cacheCreationTokens * cacheCreationRatio +
|
||||
legacyCacheCreationTokens * cacheCreationRatio +
|
||||
cacheCreationTokens5m * cacheCreationRatio5m +
|
||||
cacheCreationTokens1h * cacheCreationRatio1h;
|
||||
|
||||
|
||||
@@ -229,7 +229,7 @@ export const useApiRequest = (
|
||||
if (data.choices?.[0]) {
|
||||
const choice = data.choices[0];
|
||||
let content = choice.message?.content || '';
|
||||
let reasoningContent = choice.message?.reasoning_content || '';
|
||||
let reasoningContent = choice.message?.reasoning_content || choice.message?.reasoning || '';
|
||||
|
||||
const processed = processThinkTags(content, reasoningContent);
|
||||
|
||||
@@ -333,6 +333,9 @@ export const useApiRequest = (
|
||||
if (delta.reasoning_content) {
|
||||
streamMessageUpdate(delta.reasoning_content, 'reasoning');
|
||||
}
|
||||
if (delta.reasoning) {
|
||||
streamMessageUpdate(delta.reasoning, 'reasoning');
|
||||
}
|
||||
if (delta.content) {
|
||||
streamMessageUpdate(delta.content, 'content');
|
||||
}
|
||||
|
||||
@@ -482,6 +482,18 @@ export const useLogsData = () => {
|
||||
value: other.request_path,
|
||||
});
|
||||
}
|
||||
if (isAdminUser) {
|
||||
let localCountMode = '';
|
||||
if (other?.admin_info?.local_count_tokens) {
|
||||
localCountMode = t('本地计费');
|
||||
} else {
|
||||
localCountMode = t('上游返回');
|
||||
}
|
||||
expandDataLocal.push({
|
||||
key: t('计费模式'),
|
||||
value: localCountMode,
|
||||
});
|
||||
}
|
||||
expandDatesLocal[logs[i].key] = expandDataLocal;
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ import frTranslation from './locales/fr.json';
|
||||
import zhTranslation from './locales/zh.json';
|
||||
import ruTranslation from './locales/ru.json';
|
||||
import jaTranslation from './locales/ja.json';
|
||||
import viTranslation from './locales/vi.json';
|
||||
|
||||
i18n
|
||||
.use(LanguageDetector)
|
||||
@@ -38,6 +39,7 @@ i18n
|
||||
fr: frTranslation,
|
||||
ru: ruTranslation,
|
||||
ja: jaTranslation,
|
||||
vi: viTranslation,
|
||||
},
|
||||
fallbackLng: 'zh',
|
||||
interpolation: {
|
||||
|
||||
@@ -69,6 +69,8 @@
|
||||
"Gemini思考适配设置": "Gemini thinking adaptation settings",
|
||||
"Gemini版本设置": "Gemini version settings",
|
||||
"Gemini设置": "Gemini settings",
|
||||
"启用FunctionCall思维签名填充": "Enable FunctionCall thoughtSignature fill",
|
||||
"仅为使用OpenAI格式的Gemini/Vertex渠道填充thoughtSignature": "Fill thoughtSignature only for Gemini/Vertex channels using the OpenAI format",
|
||||
"GitHub": "GitHub",
|
||||
"GitHub Client ID": "GitHub Client ID",
|
||||
"GitHub Client Secret": "GitHub Client Secret",
|
||||
@@ -2109,6 +2111,8 @@
|
||||
"请填写完整的产品信息": "Please fill in complete product information",
|
||||
"产品ID已存在": "Product ID already exists",
|
||||
"统一的": "The Unified",
|
||||
"大模型接口网关": "LLM API Gateway"
|
||||
"大模型接口网关": "LLM API Gateway",
|
||||
"正在跳转 GitHub...": "Redirecting to GitHub...",
|
||||
"请求超时,请刷新页面后重新发起 GitHub 登录": "Request timed out, please refresh and restart GitHub login"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,6 +71,8 @@
|
||||
"Gemini思考适配设置": "Paramètres d'adaptation de la pensée Gemini",
|
||||
"Gemini版本设置": "Paramètres de version Gemini",
|
||||
"Gemini设置": "Paramètres Gemini",
|
||||
"启用FunctionCall思维签名填充": "Activer le remplissage de thoughtSignature pour FunctionCall",
|
||||
"仅为使用OpenAI格式的Gemini/Vertex渠道填充thoughtSignature": "Remplit thoughtSignature uniquement pour les canaux Gemini/Vertex utilisant le format OpenAI",
|
||||
"GitHub": "GitHub",
|
||||
"GitHub Client ID": "ID client GitHub",
|
||||
"GitHub Client Secret": "Secret client GitHub",
|
||||
@@ -2089,6 +2091,8 @@
|
||||
"默认测试模型": "Modèle de test par défaut",
|
||||
"默认补全倍率": "Taux de complétion par défaut",
|
||||
"统一的": "La Passerelle",
|
||||
"大模型接口网关": "API LLM Unifiée"
|
||||
"大模型接口网关": "API LLM Unifiée",
|
||||
"正在跳转 GitHub...": "Redirection vers GitHub...",
|
||||
"请求超时,请刷新页面后重新发起 GitHub 登录": "Délai dépassé, veuillez actualiser la page puis relancer la connexion GitHub"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,6 +69,8 @@
|
||||
"Gemini思考适配设置": "Gemini思考モード設定",
|
||||
"Gemini版本设置": "Geminiバージョン設定",
|
||||
"Gemini设置": "Gemini設定",
|
||||
"启用FunctionCall思维签名填充": "FunctionCall用のthoughtSignature自動付与を有効化",
|
||||
"仅为使用OpenAI格式的Gemini/Vertex渠道填充thoughtSignature": "OpenAI形式を利用するGemini/VertexチャネルにのみthoughtSignatureを付与します",
|
||||
"GitHub": "GitHub",
|
||||
"GitHub Client ID": "GitHub Client ID",
|
||||
"GitHub Client Secret": "GitHub Client Secret",
|
||||
@@ -2080,6 +2082,8 @@
|
||||
"默认测试模型": "デフォルトテストモデル",
|
||||
"默认补全倍率": "デフォルト補完倍率",
|
||||
"统一的": "統合型",
|
||||
"大模型接口网关": "LLM APIゲートウェイ"
|
||||
"大模型接口网关": "LLM APIゲートウェイ",
|
||||
"正在跳转 GitHub...": "GitHub にリダイレクトしています...",
|
||||
"请求超时,请刷新页面后重新发起 GitHub 登录": "タイムアウトしました。ページをリロードして GitHub ログインをやり直してください"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,6 +73,8 @@
|
||||
"Gemini思考适配设置": "Настройки адаптации мышления Gemini",
|
||||
"Gemini版本设置": "Настройки версии Gemini",
|
||||
"Gemini设置": "Настройки Gemini",
|
||||
"启用FunctionCall思维签名填充": "Включить автозаполнение thoughtSignature для FunctionCall",
|
||||
"仅为使用OpenAI格式的Gemini/Vertex渠道填充thoughtSignature": "Заполнять thoughtSignature только для каналов Gemini/Vertex, использующих формат OpenAI",
|
||||
"GitHub": "GitHub",
|
||||
"GitHub Client ID": "ID клиента GitHub",
|
||||
"GitHub Client Secret": "Секрет клиента GitHub",
|
||||
@@ -2098,6 +2100,8 @@
|
||||
"默认测试模型": "Модель для тестирования по умолчанию",
|
||||
"默认补全倍率": "Коэффициент вывода по умолчанию",
|
||||
"统一的": "Единый",
|
||||
"大模型接口网关": "Шлюз API LLM"
|
||||
"大模型接口网关": "Шлюз API LLM",
|
||||
"正在跳转 GitHub...": "Перенаправление на GitHub...",
|
||||
"请求超时,请刷新页面后重新发起 GitHub 登录": "Время ожидания истекло, обновите страницу и снова запустите вход через GitHub"
|
||||
}
|
||||
}
|
||||
|
||||
2700
web/src/i18n/locales/vi.json
Normal file
2700
web/src/i18n/locales/vi.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -67,6 +67,8 @@
|
||||
"Gemini思考适配设置": "Gemini思考适配设置",
|
||||
"Gemini版本设置": "Gemini版本设置",
|
||||
"Gemini设置": "Gemini设置",
|
||||
"启用FunctionCall思维签名填充": "启用FunctionCall思维签名填充",
|
||||
"仅为使用OpenAI格式的Gemini/Vertex渠道填充thoughtSignature": "仅为使用OpenAI格式的Gemini/Vertex渠道填充thoughtSignature",
|
||||
"GitHub": "GitHub",
|
||||
"GitHub Client ID": "GitHub Client ID",
|
||||
"GitHub Client Secret": "GitHub Client Secret",
|
||||
@@ -255,6 +257,7 @@
|
||||
"余额充值管理": "余额充值管理",
|
||||
"你似乎并没有修改什么": "你似乎并没有修改什么",
|
||||
"使用 GitHub 继续": "使用 GitHub 继续",
|
||||
"使用 Discord 继续": "使用 Discord 继续",
|
||||
"使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}": "使用 JSON 对象格式,格式为:{\"组名\": [最多请求次数, 最多请求完成次数]}",
|
||||
"使用 LinuxDO 继续": "使用 LinuxDO 继续",
|
||||
"使用 OIDC 继续": "使用 OIDC 继续",
|
||||
@@ -2071,6 +2074,8 @@
|
||||
"默认测试模型": "默认测试模型",
|
||||
"默认补全倍率": "默认补全倍率",
|
||||
"Creem 介绍": "Creem 是一个简单的支付处理平台,支持固定金额产品销售,以及订阅销售。",
|
||||
"Creem Setting Tips": "Creem 只支持预设的固定金额产品,这产品以及价格需要提前在Creem网站内创建配置,所以不支持自定义动态金额充值。在Creem端配置产品的名字以及价格,获取Product Id 后填到下面的产品,在new-api为该产品设置充值额度,以及展示价格。"
|
||||
"Creem Setting Tips": "Creem 只支持预设的固定金额产品,这产品以及价格需要提前在Creem网站内创建配置,所以不支持自定义动态金额充值。在Creem端配置产品的名字以及价格,获取Product Id 后填到下面的产品,在new-api为该产品设置充值额度,以及展示价格。",
|
||||
"正在跳转 GitHub...": "正在跳转 GitHub...",
|
||||
"请求超时,请刷新页面后重新发起 GitHub 登录": "请求超时,请刷新页面后重新发起 GitHub 登录"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,19 +39,22 @@ const GEMINI_VERSION_EXAMPLE = {
|
||||
default: 'v1beta',
|
||||
};
|
||||
|
||||
const DEFAULT_GEMINI_INPUTS = {
|
||||
'gemini.safety_settings': '',
|
||||
'gemini.version_settings': '',
|
||||
'gemini.supported_imagine_models': '',
|
||||
'gemini.thinking_adapter_enabled': false,
|
||||
'gemini.thinking_adapter_budget_tokens_percentage': 0.6,
|
||||
'gemini.function_call_thought_signature_enabled': true,
|
||||
};
|
||||
|
||||
export default function SettingGeminiModel(props) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [inputs, setInputs] = useState({
|
||||
'gemini.safety_settings': '',
|
||||
'gemini.version_settings': '',
|
||||
'gemini.supported_imagine_models': '',
|
||||
'gemini.thinking_adapter_enabled': false,
|
||||
'gemini.thinking_adapter_budget_tokens_percentage': 0.6,
|
||||
});
|
||||
const [inputs, setInputs] = useState(DEFAULT_GEMINI_INPUTS);
|
||||
const refForm = useRef();
|
||||
const [inputsRow, setInputsRow] = useState(inputs);
|
||||
const [inputsRow, setInputsRow] = useState(DEFAULT_GEMINI_INPUTS);
|
||||
|
||||
async function onSubmit() {
|
||||
await refForm.current
|
||||
@@ -92,9 +95,9 @@ export default function SettingGeminiModel(props) {
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
const currentInputs = {};
|
||||
const currentInputs = { ...DEFAULT_GEMINI_INPUTS };
|
||||
for (let key in props.options) {
|
||||
if (Object.keys(inputs).includes(key)) {
|
||||
if (Object.prototype.hasOwnProperty.call(DEFAULT_GEMINI_INPUTS, key)) {
|
||||
currentInputs[key] = props.options[key];
|
||||
}
|
||||
}
|
||||
@@ -166,6 +169,23 @@ export default function SettingGeminiModel(props) {
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Row>
|
||||
<Col span={16}>
|
||||
<Form.Switch
|
||||
label={t('启用FunctionCall思维签名填充')}
|
||||
field={'gemini.function_call_thought_signature_enabled'}
|
||||
extraText={t(
|
||||
'仅为使用OpenAI格式的Gemini/Vertex渠道填充thoughtSignature',
|
||||
)}
|
||||
onChange={(value) =>
|
||||
setInputs({
|
||||
...inputs,
|
||||
'gemini.function_call_thought_signature_enabled': value,
|
||||
})
|
||||
}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Row>
|
||||
<Col xs={24} sm={12} md={8} lg={8} xl={8}>
|
||||
<Form.TextArea
|
||||
|
||||
Reference in New Issue
Block a user