mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:05:21 +00:00
Merge branch 'alpha' into imageratio-and-audioratio-edit
This commit is contained in:
@@ -56,8 +56,6 @@
|
||||
# SESSION_SECRET=random_string
|
||||
|
||||
# 其他配置
|
||||
# 渠道测试频率(单位:秒)
|
||||
# CHANNEL_TEST_FREQUENCY=10
|
||||
# 生成默认token
|
||||
# GENERATE_DEFAULT_TOKEN=false
|
||||
# Cohere 安全设置
|
||||
|
||||
21
.github/workflows/pr-target-branch-check.yml
vendored
21
.github/workflows/pr-target-branch-check.yml
vendored
@@ -1,21 +0,0 @@
|
||||
name: Check PR Branching Strategy
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, edited]
|
||||
|
||||
jobs:
|
||||
check-branching-strategy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Enforce branching strategy
|
||||
run: |
|
||||
if [[ "${{ github.base_ref }}" == "main" ]]; then
|
||||
if [[ "${{ github.head_ref }}" != "alpha" ]]; then
|
||||
echo "Error: Pull requests to 'main' are only allowed from the 'alpha' branch."
|
||||
exit 1
|
||||
fi
|
||||
elif [[ "${{ github.base_ref }}" != "alpha" ]]; then
|
||||
echo "Error: Pull requests must be targeted to the 'alpha' or 'main' branch."
|
||||
exit 1
|
||||
fi
|
||||
echo "Branching strategy check passed."
|
||||
@@ -96,7 +96,11 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do
|
||||
- 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`)
|
||||
16. 🔄 思考转内容功能
|
||||
17. 🔄 针对用户的模型限流功能
|
||||
18. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
|
||||
18. 🔄 请求格式转换功能,支持以下三种格式转换:
|
||||
1. OpenAI Chat Completions => Claude Messages
|
||||
2. Clade Messages => OpenAI Chat Completions (可用于Claude Code调用第三方模型)
|
||||
3. OpenAI Chat Completions => Gemini Chat
|
||||
19. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费:
|
||||
1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项
|
||||
2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费
|
||||
3. 支持的渠道:
|
||||
|
||||
@@ -12,4 +12,4 @@ var LogSqlType = DatabaseTypeSQLite // Default to SQLite for logging SQL queries
|
||||
var UsingMySQL = false
|
||||
var UsingClickHouse = false
|
||||
|
||||
var SQLitePath = "one-api.db?_busy_timeout=30000"
|
||||
var SQLitePath = "one-api.db?_busy_timeout=30000"
|
||||
@@ -123,8 +123,16 @@ func Interface2String(inter interface{}) string {
|
||||
return fmt.Sprintf("%d", inter.(int))
|
||||
case float64:
|
||||
return fmt.Sprintf("%f", inter.(float64))
|
||||
case bool:
|
||||
if inter.(bool) {
|
||||
return "true"
|
||||
} else {
|
||||
return "false"
|
||||
}
|
||||
case nil:
|
||||
return ""
|
||||
}
|
||||
return "Not Implemented"
|
||||
return fmt.Sprintf("%v", inter)
|
||||
}
|
||||
|
||||
func UnescapeHTML(x string) interface{} {
|
||||
@@ -257,32 +265,32 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration")
|
||||
}
|
||||
durationStr := string(bytes.TrimSpace(output))
|
||||
if durationStr == "N/A" {
|
||||
// Create a temporary output file name
|
||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to create temporary file")
|
||||
}
|
||||
tmpName := tmpFp.Name()
|
||||
// Close immediately so ffmpeg can open the file on Windows.
|
||||
_ = tmpFp.Close()
|
||||
defer os.Remove(tmpName)
|
||||
durationStr := string(bytes.TrimSpace(output))
|
||||
if durationStr == "N/A" {
|
||||
// Create a temporary output file name
|
||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to create temporary file")
|
||||
}
|
||||
tmpName := tmpFp.Name()
|
||||
// Close immediately so ffmpeg can open the file on Windows.
|
||||
_ = tmpFp.Close()
|
||||
defer os.Remove(tmpName)
|
||||
|
||||
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
||||
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
||||
if err := ffmpegCmd.Run(); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
||||
}
|
||||
// ffmpeg -y -i filename -vcodec copy -acodec copy <tmpName>
|
||||
ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName)
|
||||
if err := ffmpegCmd.Run(); err != nil {
|
||||
return 0, errors.Wrap(err, "failed to run ffmpeg")
|
||||
}
|
||||
|
||||
// Recalculate the duration of the new file
|
||||
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
||||
output, err := c.Output()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
||||
}
|
||||
durationStr = string(bytes.TrimSpace(output))
|
||||
}
|
||||
// Recalculate the duration of the new file
|
||||
c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName)
|
||||
output, err := c.Output()
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg")
|
||||
}
|
||||
durationStr = string(bytes.TrimSpace(output))
|
||||
}
|
||||
return strconv.ParseFloat(durationStr, 64)
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -342,7 +342,7 @@ func updateChannelMoonshotBalance(channel *model.Channel) (float64, error) {
|
||||
return 0, fmt.Errorf("failed to update moonshot balance, status: %v, code: %d, scode: %s", response.Status, response.Code, response.Scode)
|
||||
}
|
||||
availableBalanceCny := response.Data.AvailableBalance
|
||||
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(setting.Price)).InexactFloat64()
|
||||
availableBalanceUsd := decimal.NewFromFloat(availableBalanceCny).Div(decimal.NewFromFloat(operation_setting.Price)).InexactFloat64()
|
||||
channel.UpdateBalance(availableBalanceUsd)
|
||||
return availableBalanceUsd, nil
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -234,7 +235,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err := service.RelayErrorHandler(httpResp, true)
|
||||
err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
|
||||
return testResult{
|
||||
context: c,
|
||||
localErr: err,
|
||||
@@ -477,15 +478,26 @@ func TestAllChannels(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func AutomaticallyTestChannels(frequency int) {
|
||||
if frequency <= 0 {
|
||||
common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test")
|
||||
return
|
||||
}
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||
common.SysLog("testing all channels")
|
||||
_ = testAllChannels(false)
|
||||
common.SysLog("channel test finished")
|
||||
}
|
||||
var autoTestChannelsOnce sync.Once
|
||||
|
||||
func AutomaticallyTestChannels() {
|
||||
autoTestChannelsOnce.Do(func() {
|
||||
for {
|
||||
if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
|
||||
time.Sleep(10 * time.Minute)
|
||||
continue
|
||||
}
|
||||
frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes
|
||||
common.SysLog(fmt.Sprintf("automatically test channels with interval %d minutes", frequency))
|
||||
for {
|
||||
time.Sleep(time.Duration(frequency) * time.Minute)
|
||||
common.SysLog("automatically testing all channels")
|
||||
_ = testAllChannels(false)
|
||||
common.SysLog("automatically channel test finished")
|
||||
if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -560,7 +561,7 @@ func AddChannel(c *gin.Context) {
|
||||
case "multi_to_single":
|
||||
addChannelRequest.Channel.ChannelInfo.IsMultiKey = true
|
||||
addChannelRequest.Channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
array, err := getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -585,7 +586,7 @@ func AddChannel(c *gin.Context) {
|
||||
}
|
||||
keys = []string{addChannelRequest.Channel.Key}
|
||||
case "batch":
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi {
|
||||
if addChannelRequest.Channel.Type == constant.ChannelTypeVertexAi && addChannelRequest.Channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
// multi json
|
||||
keys, err = getVertexArrayKeys(addChannelRequest.Channel.Key)
|
||||
if err != nil {
|
||||
@@ -840,7 +841,7 @@ func UpdateChannel(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 处理 Vertex AI 的特殊情况
|
||||
if channel.Type == constant.ChannelTypeVertexAi {
|
||||
if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
// 尝试解析新密钥为JSON数组
|
||||
if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
|
||||
array, err := getVertexArrayKeys(channel.Key)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -259,7 +260,7 @@ func GetAllMidjourney(c *gin.Context) {
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range items {
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
@@ -284,7 +285,7 @@ func GetUserMidjourney(c *gin.Context) {
|
||||
|
||||
if setting.MjForwardUrlEnabled {
|
||||
for i, midjourney := range items {
|
||||
midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||
items[i] = midjourney
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,6 +39,8 @@ func TestStatus(c *gin.Context) {
|
||||
func GetStatus(c *gin.Context) {
|
||||
|
||||
cs := console_setting.GetConsoleSetting()
|
||||
common.OptionMapRWMutex.RLock()
|
||||
defer common.OptionMapRWMutex.RUnlock()
|
||||
|
||||
data := gin.H{
|
||||
"version": common.Version,
|
||||
@@ -56,11 +58,7 @@ func GetStatus(c *gin.Context) {
|
||||
"footer_html": common.Footer,
|
||||
"wechat_qrcode": common.WeChatAccountQRCodeImageURL,
|
||||
"wechat_login": common.WeChatAuthEnabled,
|
||||
"server_address": setting.ServerAddress,
|
||||
"price": setting.Price,
|
||||
"stripe_unit_price": setting.StripeUnitPrice,
|
||||
"min_topup": setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"server_address": system_setting.ServerAddress,
|
||||
"turnstile_check": common.TurnstileCheckEnabled,
|
||||
"turnstile_site_key": common.TurnstileSiteKey,
|
||||
"top_up_link": common.TopUpLink,
|
||||
@@ -73,15 +71,15 @@ func GetStatus(c *gin.Context) {
|
||||
"enable_data_export": common.DataExportEnabled,
|
||||
"data_export_default_time": common.DataExportDefaultTime,
|
||||
"default_collapse_sidebar": common.DefaultCollapseSidebar,
|
||||
"enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"mj_notify_enabled": setting.MjNotifyEnabled,
|
||||
"chats": setting.Chats,
|
||||
"demo_site_enabled": operation_setting.DemoSiteEnabled,
|
||||
"self_use_mode_enabled": operation_setting.SelfUseModeEnabled,
|
||||
"default_use_auto_group": setting.DefaultUseAutoGroup,
|
||||
"pay_methods": setting.PayMethods,
|
||||
"usd_exchange_rate": setting.USDExchangeRate,
|
||||
|
||||
"usd_exchange_rate": operation_setting.USDExchangeRate,
|
||||
"price": operation_setting.Price,
|
||||
"stripe_unit_price": setting.StripeUnitPrice,
|
||||
|
||||
// 面板启用开关
|
||||
"api_info_enabled": cs.ApiInfoEnabled,
|
||||
@@ -89,6 +87,10 @@ func GetStatus(c *gin.Context) {
|
||||
"announcements_enabled": cs.AnnouncementsEnabled,
|
||||
"faq_enabled": cs.FAQEnabled,
|
||||
|
||||
// 模块管理配置
|
||||
"HeaderNavModules": common.OptionMap["HeaderNavModules"],
|
||||
"SidebarModulesAdmin": common.OptionMap["SidebarModulesAdmin"],
|
||||
|
||||
"oidc_enabled": system_setting.GetOIDCSettings().Enabled,
|
||||
"oidc_client_id": system_setting.GetOIDCSettings().ClientId,
|
||||
"oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint,
|
||||
@@ -247,7 +249,7 @@ func SendPasswordResetEmail(c *gin.Context) {
|
||||
}
|
||||
code := common.GenerateVerificationCode(0)
|
||||
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code)
|
||||
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", system_setting.ServerAddress, email, code)
|
||||
subject := fmt.Sprintf("%s密码重置", common.SystemName)
|
||||
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
|
||||
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
|
||||
|
||||
@@ -207,6 +207,7 @@ func ListModels(c *gin.Context, modelType int) {
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"data": userOpenAiModels,
|
||||
"object": "list",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
604
controller/model_sync.go
Normal file
604
controller/model_sync.go
Normal file
@@ -0,0 +1,604 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 上游地址
|
||||
const (
|
||||
upstreamModelsURL = "https://basellm.github.io/llm-metadata/api/newapi/models.json"
|
||||
upstreamVendorsURL = "https://basellm.github.io/llm-metadata/api/newapi/vendors.json"
|
||||
)
|
||||
|
||||
func normalizeLocale(locale string) (string, bool) {
|
||||
l := strings.ToLower(strings.TrimSpace(locale))
|
||||
switch l {
|
||||
case "en", "zh", "ja":
|
||||
return l, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func getUpstreamBase() string {
|
||||
return common.GetEnvOrDefaultString("SYNC_UPSTREAM_BASE", "https://basellm.github.io/llm-metadata")
|
||||
}
|
||||
|
||||
func getUpstreamURLs(locale string) (modelsURL, vendorsURL string) {
|
||||
base := strings.TrimRight(getUpstreamBase(), "/")
|
||||
if l, ok := normalizeLocale(locale); ok && l != "" {
|
||||
return fmt.Sprintf("%s/api/i18n/%s/newapi/models.json", base, l),
|
||||
fmt.Sprintf("%s/api/i18n/%s/newapi/vendors.json", base, l)
|
||||
}
|
||||
return fmt.Sprintf("%s/api/newapi/models.json", base), fmt.Sprintf("%s/api/newapi/vendors.json", base)
|
||||
}
|
||||
|
||||
type upstreamEnvelope[T any] struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data []T `json:"data"`
|
||||
}
|
||||
|
||||
type upstreamModel struct {
|
||||
Description string `json:"description"`
|
||||
Endpoints json.RawMessage `json:"endpoints"`
|
||||
Icon string `json:"icon"`
|
||||
ModelName string `json:"model_name"`
|
||||
NameRule int `json:"name_rule"`
|
||||
Status int `json:"status"`
|
||||
Tags string `json:"tags"`
|
||||
VendorName string `json:"vendor_name"`
|
||||
}
|
||||
|
||||
type upstreamVendor struct {
|
||||
Description string `json:"description"`
|
||||
Icon string `json:"icon"`
|
||||
Name string `json:"name"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
var (
|
||||
etagCache = make(map[string]string)
|
||||
bodyCache = make(map[string][]byte)
|
||||
cacheMutex sync.RWMutex
|
||||
)
|
||||
|
||||
type overwriteField struct {
|
||||
ModelName string `json:"model_name"`
|
||||
Fields []string `json:"fields"`
|
||||
}
|
||||
|
||||
type syncRequest struct {
|
||||
Overwrite []overwriteField `json:"overwrite"`
|
||||
Locale string `json:"locale"`
|
||||
}
|
||||
|
||||
func newHTTPClient() *http.Client {
|
||||
timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 10)
|
||||
dialer := &net.Dialer{Timeout: time.Duration(timeoutSec) * time.Second}
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: time.Duration(timeoutSec) * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: time.Duration(timeoutSec) * time.Second,
|
||||
}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
if strings.HasSuffix(host, "github.io") {
|
||||
if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
return dialer.DialContext(ctx, "tcp6", addr)
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
return &http.Client{Transport: transport}
|
||||
}
|
||||
|
||||
var httpClient = newHTTPClient()
|
||||
|
||||
func fetchJSON[T any](ctx context.Context, url string, out *upstreamEnvelope[T]) error {
|
||||
var lastErr error
|
||||
attempts := common.GetEnvOrDefault("SYNC_HTTP_RETRY", 3)
|
||||
if attempts < 1 {
|
||||
attempts = 1
|
||||
}
|
||||
baseDelay := 200 * time.Millisecond
|
||||
maxMB := common.GetEnvOrDefault("SYNC_HTTP_MAX_MB", 10)
|
||||
maxBytes := int64(maxMB) << 20
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// ETag conditional request
|
||||
cacheMutex.RLock()
|
||||
if et := etagCache[url]; et != "" {
|
||||
req.Header.Set("If-None-Match", et)
|
||||
}
|
||||
cacheMutex.RUnlock()
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
// backoff with jitter
|
||||
sleep := baseDelay * time.Duration(1<<attempt)
|
||||
jitter := time.Duration(rand.Intn(150)) * time.Millisecond
|
||||
time.Sleep(sleep + jitter)
|
||||
continue
|
||||
}
|
||||
func() {
|
||||
defer resp.Body.Close()
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
// read body into buffer for caching and flexible decode
|
||||
limited := io.LimitReader(resp.Body, maxBytes)
|
||||
buf, err := io.ReadAll(limited)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
return
|
||||
}
|
||||
// cache body and ETag
|
||||
cacheMutex.Lock()
|
||||
if et := resp.Header.Get("ETag"); et != "" {
|
||||
etagCache[url] = et
|
||||
}
|
||||
bodyCache[url] = buf
|
||||
cacheMutex.Unlock()
|
||||
|
||||
// Try decode as envelope first
|
||||
if err := json.Unmarshal(buf, out); err != nil {
|
||||
// Try decode as pure array
|
||||
var arr []T
|
||||
if err2 := json.Unmarshal(buf, &arr); err2 != nil {
|
||||
lastErr = err
|
||||
return
|
||||
}
|
||||
out.Success = true
|
||||
out.Data = arr
|
||||
out.Message = ""
|
||||
} else {
|
||||
if !out.Success && len(out.Data) == 0 && out.Message == "" {
|
||||
out.Success = true
|
||||
}
|
||||
}
|
||||
lastErr = nil
|
||||
case http.StatusNotModified:
|
||||
// use cache
|
||||
cacheMutex.RLock()
|
||||
buf := bodyCache[url]
|
||||
cacheMutex.RUnlock()
|
||||
if len(buf) == 0 {
|
||||
lastErr = errors.New("cache miss for 304 response")
|
||||
return
|
||||
}
|
||||
if err := json.Unmarshal(buf, out); err != nil {
|
||||
var arr []T
|
||||
if err2 := json.Unmarshal(buf, &arr); err2 != nil {
|
||||
lastErr = err
|
||||
return
|
||||
}
|
||||
out.Success = true
|
||||
out.Data = arr
|
||||
out.Message = ""
|
||||
} else {
|
||||
if !out.Success && len(out.Data) == 0 && out.Message == "" {
|
||||
out.Success = true
|
||||
}
|
||||
}
|
||||
lastErr = nil
|
||||
default:
|
||||
lastErr = errors.New(resp.Status)
|
||||
}
|
||||
}()
|
||||
if lastErr == nil {
|
||||
return nil
|
||||
}
|
||||
sleep := baseDelay * time.Duration(1<<attempt)
|
||||
jitter := time.Duration(rand.Intn(150)) * time.Millisecond
|
||||
time.Sleep(sleep + jitter)
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func ensureVendorID(vendorName string, vendorByName map[string]upstreamVendor, vendorIDCache map[string]int, createdVendors *int) int {
|
||||
if vendorName == "" {
|
||||
return 0
|
||||
}
|
||||
if id, ok := vendorIDCache[vendorName]; ok {
|
||||
return id
|
||||
}
|
||||
var existing model.Vendor
|
||||
if err := model.DB.Where("name = ?", vendorName).First(&existing).Error; err == nil {
|
||||
vendorIDCache[vendorName] = existing.Id
|
||||
return existing.Id
|
||||
}
|
||||
uv := vendorByName[vendorName]
|
||||
v := &model.Vendor{
|
||||
Name: vendorName,
|
||||
Description: uv.Description,
|
||||
Icon: coalesce(uv.Icon, ""),
|
||||
Status: chooseStatus(uv.Status, 1),
|
||||
}
|
||||
if err := v.Insert(); err == nil {
|
||||
*createdVendors++
|
||||
vendorIDCache[vendorName] = v.Id
|
||||
return v.Id
|
||||
}
|
||||
vendorIDCache[vendorName] = 0
|
||||
return 0
|
||||
}
|
||||
|
||||
// SyncUpstreamModels 同步上游模型与供应商,仅对「未配置模型」生效
|
||||
func SyncUpstreamModels(c *gin.Context) {
|
||||
var req syncRequest
|
||||
// 允许空体
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
// 1) 获取未配置模型列表
|
||||
missing, err := model.GetMissingModels()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
if len(missing) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "data": gin.H{
|
||||
"created_models": 0,
|
||||
"created_vendors": 0,
|
||||
"skipped_models": []string{},
|
||||
}})
|
||||
return
|
||||
}
|
||||
|
||||
// 2) 拉取上游 vendors 与 models
|
||||
timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 15)
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
modelsURL, vendorsURL := getUpstreamURLs(req.Locale)
|
||||
var vendorsEnv upstreamEnvelope[upstreamVendor]
|
||||
var modelsEnv upstreamEnvelope[upstreamModel]
|
||||
var fetchErr error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// vendor 失败不拦截
|
||||
_ = fetchJSON(ctx, vendorsURL, &vendorsEnv)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := fetchJSON(ctx, modelsURL, &modelsEnv); err != nil {
|
||||
fetchErr = err
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
if fetchErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + fetchErr.Error(), "locale": req.Locale, "source_urls": gin.H{"models_url": modelsURL, "vendors_url": vendorsURL}})
|
||||
return
|
||||
}
|
||||
|
||||
// 建立映射
|
||||
vendorByName := make(map[string]upstreamVendor)
|
||||
for _, v := range vendorsEnv.Data {
|
||||
if v.Name != "" {
|
||||
vendorByName[v.Name] = v
|
||||
}
|
||||
}
|
||||
modelByName := make(map[string]upstreamModel)
|
||||
for _, m := range modelsEnv.Data {
|
||||
if m.ModelName != "" {
|
||||
modelByName[m.ModelName] = m
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 执行同步:仅创建缺失模型;若上游缺失该模型则跳过
|
||||
createdModels := 0
|
||||
createdVendors := 0
|
||||
updatedModels := 0
|
||||
var skipped []string
|
||||
var createdList []string
|
||||
var updatedList []string
|
||||
|
||||
// 本地缓存:vendorName -> id
|
||||
vendorIDCache := make(map[string]int)
|
||||
|
||||
for _, name := range missing {
|
||||
up, ok := modelByName[name]
|
||||
if !ok {
|
||||
skipped = append(skipped, name)
|
||||
continue
|
||||
}
|
||||
|
||||
// 若本地已存在且设置为不同步,则跳过(极端情况:缺失列表与本地状态不同步时)
|
||||
var existing model.Model
|
||||
if err := model.DB.Where("model_name = ?", name).First(&existing).Error; err == nil {
|
||||
if existing.SyncOfficial == 0 {
|
||||
skipped = append(skipped, name)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 vendor 存在
|
||||
vendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
|
||||
|
||||
// 创建模型
|
||||
mi := &model.Model{
|
||||
ModelName: name,
|
||||
Description: up.Description,
|
||||
Icon: up.Icon,
|
||||
Tags: up.Tags,
|
||||
VendorID: vendorID,
|
||||
Status: chooseStatus(up.Status, 1),
|
||||
NameRule: up.NameRule,
|
||||
}
|
||||
if err := mi.Insert(); err == nil {
|
||||
createdModels++
|
||||
createdList = append(createdList, name)
|
||||
} else {
|
||||
skipped = append(skipped, name)
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 处理可选覆盖(更新本地已有模型的差异字段)
|
||||
if len(req.Overwrite) > 0 {
|
||||
// vendorIDCache 已用于创建阶段,可复用
|
||||
for _, ow := range req.Overwrite {
|
||||
up, ok := modelByName[ow.ModelName]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var local model.Model
|
||||
if err := model.DB.Where("model_name = ?", ow.ModelName).First(&local).Error; err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 跳过被禁用官方同步的模型
|
||||
if local.SyncOfficial == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 映射 vendor
|
||||
newVendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors)
|
||||
|
||||
// 应用字段覆盖(事务)
|
||||
_ = model.DB.Transaction(func(tx *gorm.DB) error {
|
||||
needUpdate := false
|
||||
if containsField(ow.Fields, "description") {
|
||||
local.Description = up.Description
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "icon") {
|
||||
local.Icon = up.Icon
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "tags") {
|
||||
local.Tags = up.Tags
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "vendor") {
|
||||
local.VendorID = newVendorID
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "name_rule") {
|
||||
local.NameRule = up.NameRule
|
||||
needUpdate = true
|
||||
}
|
||||
if containsField(ow.Fields, "status") {
|
||||
local.Status = chooseStatus(up.Status, local.Status)
|
||||
needUpdate = true
|
||||
}
|
||||
if !needUpdate {
|
||||
return nil
|
||||
}
|
||||
if err := tx.Save(&local).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
updatedModels++
|
||||
updatedList = append(updatedList, ow.ModelName)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"created_models": createdModels,
|
||||
"created_vendors": createdVendors,
|
||||
"updated_models": updatedModels,
|
||||
"skipped_models": skipped,
|
||||
"created_list": createdList,
|
||||
"updated_list": updatedList,
|
||||
"source": gin.H{
|
||||
"locale": req.Locale,
|
||||
"models_url": modelsURL,
|
||||
"vendors_url": vendorsURL,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func containsField(fields []string, key string) bool {
|
||||
key = strings.ToLower(strings.TrimSpace(key))
|
||||
for _, f := range fields {
|
||||
if strings.ToLower(strings.TrimSpace(f)) == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func coalesce(a, b string) string {
|
||||
if strings.TrimSpace(a) != "" {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func chooseStatus(primary, fallback int) int {
|
||||
if primary == 0 && fallback != 0 {
|
||||
return fallback
|
||||
}
|
||||
if primary != 0 {
|
||||
return primary
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// SyncUpstreamPreview 预览上游与本地的差异(仅用于弹窗选择)
|
||||
func SyncUpstreamPreview(c *gin.Context) {
|
||||
// 1) 拉取上游数据
|
||||
timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 15)
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
locale := c.Query("locale")
|
||||
modelsURL, vendorsURL := getUpstreamURLs(locale)
|
||||
|
||||
var vendorsEnv upstreamEnvelope[upstreamVendor]
|
||||
var modelsEnv upstreamEnvelope[upstreamModel]
|
||||
var fetchErr error
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = fetchJSON(ctx, vendorsURL, &vendorsEnv)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := fetchJSON(ctx, modelsURL, &modelsEnv); err != nil {
|
||||
fetchErr = err
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
if fetchErr != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + fetchErr.Error(), "locale": locale, "source_urls": gin.H{"models_url": modelsURL, "vendors_url": vendorsURL}})
|
||||
return
|
||||
}
|
||||
|
||||
vendorByName := make(map[string]upstreamVendor)
|
||||
for _, v := range vendorsEnv.Data {
|
||||
if v.Name != "" {
|
||||
vendorByName[v.Name] = v
|
||||
}
|
||||
}
|
||||
modelByName := make(map[string]upstreamModel)
|
||||
upstreamNames := make([]string, 0, len(modelsEnv.Data))
|
||||
for _, m := range modelsEnv.Data {
|
||||
if m.ModelName != "" {
|
||||
modelByName[m.ModelName] = m
|
||||
upstreamNames = append(upstreamNames, m.ModelName)
|
||||
}
|
||||
}
|
||||
|
||||
// 2) 本地已有模型
|
||||
var locals []model.Model
|
||||
if len(upstreamNames) > 0 {
|
||||
_ = model.DB.Where("model_name IN ? AND sync_official <> 0", upstreamNames).Find(&locals).Error
|
||||
}
|
||||
|
||||
// 本地 vendor 名称映射
|
||||
vendorIdSet := make(map[int]struct{})
|
||||
for _, m := range locals {
|
||||
if m.VendorID != 0 {
|
||||
vendorIdSet[m.VendorID] = struct{}{}
|
||||
}
|
||||
}
|
||||
vendorIDs := make([]int, 0, len(vendorIdSet))
|
||||
for id := range vendorIdSet {
|
||||
vendorIDs = append(vendorIDs, id)
|
||||
}
|
||||
idToVendorName := make(map[int]string)
|
||||
if len(vendorIDs) > 0 {
|
||||
var dbVendors []model.Vendor
|
||||
_ = model.DB.Where("id IN ?", vendorIDs).Find(&dbVendors).Error
|
||||
for _, v := range dbVendors {
|
||||
idToVendorName[v.Id] = v.Name
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 缺失且上游存在的模型
|
||||
missingList, _ := model.GetMissingModels()
|
||||
var missing []string
|
||||
for _, name := range missingList {
|
||||
if _, ok := modelByName[name]; ok {
|
||||
missing = append(missing, name)
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 计算冲突字段
|
||||
type conflictField struct {
|
||||
Field string `json:"field"`
|
||||
Local interface{} `json:"local"`
|
||||
Upstream interface{} `json:"upstream"`
|
||||
}
|
||||
type conflictItem struct {
|
||||
ModelName string `json:"model_name"`
|
||||
Fields []conflictField `json:"fields"`
|
||||
}
|
||||
|
||||
var conflicts []conflictItem
|
||||
for _, local := range locals {
|
||||
up, ok := modelByName[local.ModelName]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
fields := make([]conflictField, 0, 6)
|
||||
if strings.TrimSpace(local.Description) != strings.TrimSpace(up.Description) {
|
||||
fields = append(fields, conflictField{Field: "description", Local: local.Description, Upstream: up.Description})
|
||||
}
|
||||
if strings.TrimSpace(local.Icon) != strings.TrimSpace(up.Icon) {
|
||||
fields = append(fields, conflictField{Field: "icon", Local: local.Icon, Upstream: up.Icon})
|
||||
}
|
||||
if strings.TrimSpace(local.Tags) != strings.TrimSpace(up.Tags) {
|
||||
fields = append(fields, conflictField{Field: "tags", Local: local.Tags, Upstream: up.Tags})
|
||||
}
|
||||
// vendor 对比使用名称
|
||||
localVendor := idToVendorName[local.VendorID]
|
||||
if strings.TrimSpace(localVendor) != strings.TrimSpace(up.VendorName) {
|
||||
fields = append(fields, conflictField{Field: "vendor", Local: localVendor, Upstream: up.VendorName})
|
||||
}
|
||||
if local.NameRule != up.NameRule {
|
||||
fields = append(fields, conflictField{Field: "name_rule", Local: local.NameRule, Upstream: up.NameRule})
|
||||
}
|
||||
if local.Status != chooseStatus(up.Status, local.Status) {
|
||||
fields = append(fields, conflictField{Field: "status", Local: local.Status, Upstream: up.Status})
|
||||
}
|
||||
if len(fields) > 0 {
|
||||
conflicts = append(conflicts, conflictItem{ModelName: local.ModelName, Fields: fields})
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": gin.H{
|
||||
"missing": missing,
|
||||
"conflicts": conflicts,
|
||||
"source": gin.H{
|
||||
"locale": locale,
|
||||
"models_url": modelsURL,
|
||||
"vendors_url": vendorsURL,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -45,7 +44,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) {
|
||||
values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret)
|
||||
values.Set("code", code)
|
||||
values.Set("grant_type", "authorization_code")
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress))
|
||||
values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress))
|
||||
formData := values.Encode()
|
||||
req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData))
|
||||
if err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
@@ -35,8 +36,13 @@ func GetOptions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
type OptionUpdateRequest struct {
|
||||
Key string `json:"key"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
func UpdateOption(c *gin.Context) {
|
||||
var option model.Option
|
||||
var option OptionUpdateRequest
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&option)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
@@ -45,6 +51,16 @@ func UpdateOption(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
switch option.Value.(type) {
|
||||
case bool:
|
||||
option.Value = common.Interface2String(option.Value.(bool))
|
||||
case float64:
|
||||
option.Value = common.Interface2String(option.Value.(float64))
|
||||
case int:
|
||||
option.Value = common.Interface2String(option.Value.(int))
|
||||
default:
|
||||
option.Value = fmt.Sprintf("%v", option.Value)
|
||||
}
|
||||
switch option.Key {
|
||||
case "GitHubOAuthEnabled":
|
||||
if option.Value == "true" && common.GitHubClientId == "" {
|
||||
@@ -104,7 +120,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "GroupRatio":
|
||||
err = ratio_setting.CheckGroupRatio(option.Value)
|
||||
err = ratio_setting.CheckGroupRatio(option.Value.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -140,7 +156,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "ModelRequestRateLimitGroup":
|
||||
err = setting.CheckModelRequestRateLimitGroup(option.Value)
|
||||
err = setting.CheckModelRequestRateLimitGroup(option.Value.(string))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -149,7 +165,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "console_setting.api_info":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo")
|
||||
err = console_setting.ValidateConsoleSettings(option.Value.(string), "ApiInfo")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -158,7 +174,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "console_setting.announcements":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "Announcements")
|
||||
err = console_setting.ValidateConsoleSettings(option.Value.(string), "Announcements")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -167,7 +183,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "console_setting.faq":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "FAQ")
|
||||
err = console_setting.ValidateConsoleSettings(option.Value.(string), "FAQ")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -176,7 +192,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
case "console_setting.uptime_kuma_groups":
|
||||
err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups")
|
||||
err = console_setting.ValidateConsoleSettings(option.Value.(string), "UptimeKumaGroups")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -185,7 +201,7 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
err = model.UpdateOption(option.Key, option.Value)
|
||||
err = model.UpdateOption(option.Key, option.Value.(string))
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/setting/ratio_setting"
|
||||
"net/http"
|
||||
"one-api/setting/ratio_setting"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetRatioConfig(c *gin.Context) {
|
||||
if !ratio_setting.IsExposeRatioEnabled() {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "倍率配置接口未启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
if !ratio_setting.IsExposeRatioEnabled() {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "倍率配置接口未启用",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": ratio_setting.GetExposedData(),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": ratio_setting.GetExposedData(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"one-api/logger"
|
||||
"strings"
|
||||
@@ -21,8 +23,26 @@ const (
|
||||
defaultTimeoutSeconds = 10
|
||||
defaultEndpoint = "/api/ratio_config"
|
||||
maxConcurrentFetches = 8
|
||||
maxRatioConfigBytes = 10 << 20 // 10MB
|
||||
floatEpsilon = 1e-9
|
||||
)
|
||||
|
||||
func nearlyEqual(a, b float64) bool {
|
||||
if a > b {
|
||||
return a-b < floatEpsilon
|
||||
}
|
||||
return b-a < floatEpsilon
|
||||
}
|
||||
|
||||
func valuesEqual(a, b interface{}) bool {
|
||||
af, aok := a.(float64)
|
||||
bf, bok := b.(float64)
|
||||
if aok && bok {
|
||||
return nearlyEqual(af, bf)
|
||||
}
|
||||
return a == b
|
||||
}
|
||||
|
||||
var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"}
|
||||
|
||||
type upstreamResult struct {
|
||||
@@ -87,7 +107,23 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
|
||||
sem := make(chan struct{}, maxConcurrentFetches)
|
||||
|
||||
client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}}
|
||||
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||
transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second}
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
// 对 github.io 优先尝试 IPv4,失败则回退 IPv6
|
||||
if strings.HasSuffix(host, "github.io") {
|
||||
if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
return dialer.DialContext(ctx, "tcp6", addr)
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
client := &http.Client{Transport: transport}
|
||||
|
||||
for _, chn := range upstreams {
|
||||
wg.Add(1)
|
||||
@@ -98,12 +134,17 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
defer func() { <-sem }()
|
||||
|
||||
endpoint := chItem.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = defaultEndpoint
|
||||
} else if !strings.HasPrefix(endpoint, "/") {
|
||||
endpoint = "/" + endpoint
|
||||
var fullURL string
|
||||
if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
|
||||
fullURL = endpoint
|
||||
} else {
|
||||
if endpoint == "" {
|
||||
endpoint = defaultEndpoint
|
||||
} else if !strings.HasPrefix(endpoint, "/") {
|
||||
endpoint = "/" + endpoint
|
||||
}
|
||||
fullURL = chItem.BaseURL + endpoint
|
||||
}
|
||||
fullURL := chItem.BaseURL + endpoint
|
||||
|
||||
uniqueName := chItem.Name
|
||||
if chItem.ID != 0 {
|
||||
@@ -120,10 +161,19 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
// 简单重试:最多 3 次,指数退避
|
||||
var resp *http.Response
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < 3; attempt++ {
|
||||
resp, lastErr = client.Do(httpReq)
|
||||
if lastErr == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
|
||||
}
|
||||
if lastErr != nil {
|
||||
logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+lastErr.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: lastErr.Error()}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
@@ -132,6 +182,12 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
||||
return
|
||||
}
|
||||
|
||||
// Content-Type 和响应体大小校验
|
||||
if ct := resp.Header.Get("Content-Type"); ct != "" && !strings.Contains(strings.ToLower(ct), "application/json") {
|
||||
logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct)
|
||||
}
|
||||
limited := io.LimitReader(resp.Body, maxRatioConfigBytes)
|
||||
// 兼容两种上游接口格式:
|
||||
// type1: /api/ratio_config -> data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price
|
||||
// type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式
|
||||
@@ -141,7 +197,7 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
if err := json.NewDecoder(limited).Decode(&body); err != nil {
|
||||
logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
||||
ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
||||
return
|
||||
@@ -152,6 +208,8 @@ func FetchUpstreamRatios(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 若 Data 为空,将继续按 type1 尝试解析(与多数静态 ratio_config 兼容)
|
||||
|
||||
// 尝试按 type1 解析
|
||||
var type1Data map[string]any
|
||||
if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
||||
@@ -357,9 +415,9 @@ func buildDifferences(localData map[string]any, successfulChannels []struct {
|
||||
upstreamValue = val
|
||||
hasUpstreamValue = true
|
||||
|
||||
if localValue != nil && localValue != val {
|
||||
if localValue != nil && !valuesEqual(localValue, val) {
|
||||
hasDifference = true
|
||||
} else if localValue == val {
|
||||
} else if valuesEqual(localValue, val) {
|
||||
upstreamValue = "same"
|
||||
}
|
||||
}
|
||||
@@ -466,6 +524,13 @@ func GetSyncableChannels(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
||||
ID: -100,
|
||||
Name: "官方倍率预设",
|
||||
BaseURL: "https://basellm.github.io",
|
||||
Status: 1,
|
||||
})
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
|
||||
@@ -139,15 +139,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
||||
|
||||
// common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta)
|
||||
|
||||
preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||
if newAPIError != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// Only return quota if downstream failed and quota was actually pre-consumed
|
||||
if newAPIError != nil && preConsumedQuota != 0 {
|
||||
service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota)
|
||||
if newAPIError != nil && relayInfo.FinalPreConsumedQuota != 0 {
|
||||
service.ReturnPreConsumedQuota(c, relayInfo)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -277,14 +277,13 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
|
||||
|
||||
func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
|
||||
logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
|
||||
|
||||
gopool.Go(func() {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||
if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
|
||||
gopool.Go(func() {
|
||||
service.DisableChannel(channelError, err.Error())
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
|
||||
// 保存错误日志到mysql中
|
||||
|
||||
@@ -178,4 +178,4 @@ func boolToString(b bool) string {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
}
|
||||
}
|
||||
@@ -94,7 +94,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
} else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil {
|
||||
return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err)
|
||||
} else {
|
||||
task.Data = responseBody
|
||||
task.Data = redactVideoResponseBody(responseBody)
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
@@ -113,11 +113,13 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
task.StartTime = now
|
||||
}
|
||||
case model.TaskStatusSuccess:
|
||||
task.Progress = "100%"
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
task.FinishTime = now
|
||||
}
|
||||
task.FailReason = taskResult.Url
|
||||
if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") {
|
||||
task.FailReason = taskResult.Url
|
||||
}
|
||||
case model.TaskStatusFailure:
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
@@ -146,3 +148,37 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func redactVideoResponseBody(body []byte) []byte {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(body, &m); err != nil {
|
||||
return body
|
||||
}
|
||||
resp, _ := m["response"].(map[string]any)
|
||||
if resp != nil {
|
||||
delete(resp, "bytesBase64Encoded")
|
||||
if v, ok := resp["video"].(string); ok {
|
||||
resp["video"] = truncateBase64(v)
|
||||
}
|
||||
if vs, ok := resp["videos"].([]any); ok {
|
||||
for i := range vs {
|
||||
if vm, ok := vs[i].(map[string]any); ok {
|
||||
delete(vm, "bytesBase64Encoded")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func truncateBase64(s string) string {
|
||||
const maxKeep = 256
|
||||
if len(s) <= maxKeep {
|
||||
return s
|
||||
}
|
||||
return s[:maxKeep] + "..."
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"one-api/model"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -19,6 +21,44 @@ import (
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
func GetTopUpInfo(c *gin.Context) {
|
||||
// 获取支付方式
|
||||
payMethods := operation_setting.PayMethods
|
||||
|
||||
// 如果启用了 Stripe 支付,添加到支付方法列表
|
||||
if setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "" {
|
||||
// 检查是否已经包含 Stripe
|
||||
hasStripe := false
|
||||
for _, method := range payMethods {
|
||||
if method["type"] == "stripe" {
|
||||
hasStripe = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasStripe {
|
||||
stripeMethod := map[string]string{
|
||||
"name": "Stripe",
|
||||
"type": "stripe",
|
||||
"color": "rgba(var(--semi-purple-5), 1)",
|
||||
"min_topup": strconv.Itoa(setting.StripeMinTopUp),
|
||||
}
|
||||
payMethods = append(payMethods, stripeMethod)
|
||||
}
|
||||
}
|
||||
|
||||
data := gin.H{
|
||||
"enable_online_topup": operation_setting.PayAddress != "" && operation_setting.EpayId != "" && operation_setting.EpayKey != "",
|
||||
"enable_stripe_topup": setting.StripeApiSecret != "" && setting.StripeWebhookSecret != "" && setting.StripePriceId != "",
|
||||
"pay_methods": payMethods,
|
||||
"min_topup": operation_setting.MinTopUp,
|
||||
"stripe_min_topup": setting.StripeMinTopUp,
|
||||
"amount_options": operation_setting.GetPaymentSetting().AmountOptions,
|
||||
"discount": operation_setting.GetPaymentSetting().AmountDiscount,
|
||||
}
|
||||
common.ApiSuccess(c, data)
|
||||
}
|
||||
|
||||
type EpayRequest struct {
|
||||
Amount int64 `json:"amount"`
|
||||
PaymentMethod string `json:"payment_method"`
|
||||
@@ -31,13 +71,13 @@ type AmountRequest struct {
|
||||
}
|
||||
|
||||
func GetEpayClient() *epay.Client {
|
||||
if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" {
|
||||
if operation_setting.PayAddress == "" || operation_setting.EpayId == "" || operation_setting.EpayKey == "" {
|
||||
return nil
|
||||
}
|
||||
withUrl, err := epay.NewClient(&epay.Config{
|
||||
PartnerID: setting.EpayId,
|
||||
Key: setting.EpayKey,
|
||||
}, setting.PayAddress)
|
||||
PartnerID: operation_setting.EpayId,
|
||||
Key: operation_setting.EpayKey,
|
||||
}, operation_setting.PayAddress)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -58,15 +98,23 @@ func getPayMoney(amount int64, group string) float64 {
|
||||
}
|
||||
|
||||
dTopupGroupRatio := decimal.NewFromFloat(topupGroupRatio)
|
||||
dPrice := decimal.NewFromFloat(setting.Price)
|
||||
dPrice := decimal.NewFromFloat(operation_setting.Price)
|
||||
// apply optional preset discount by the original request amount (if configured), default 1.0
|
||||
discount := 1.0
|
||||
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(amount)]; ok {
|
||||
if ds > 0 {
|
||||
discount = ds
|
||||
}
|
||||
}
|
||||
dDiscount := decimal.NewFromFloat(discount)
|
||||
|
||||
payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio)
|
||||
payMoney := dAmount.Mul(dPrice).Mul(dTopupGroupRatio).Mul(dDiscount)
|
||||
|
||||
return payMoney.InexactFloat64()
|
||||
}
|
||||
|
||||
func getMinTopup() int64 {
|
||||
minTopup := setting.MinTopUp
|
||||
minTopup := operation_setting.MinTopUp
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
dMinTopup := decimal.NewFromInt(int64(minTopup))
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
@@ -99,13 +147,13 @@ func RequestEpay(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if !setting.ContainsPayMethod(req.PaymentMethod) {
|
||||
if !operation_setting.ContainsPayMethod(req.PaymentMethod) {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "支付方式不存在"})
|
||||
return
|
||||
}
|
||||
|
||||
callBackAddress := service.GetCallbackAddress()
|
||||
returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log")
|
||||
returnUrl, _ := url.Parse(system_setting.ServerAddress + "/console/log")
|
||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -215,8 +217,8 @@ func genStripeLink(referenceId string, customerId string, email string, amount i
|
||||
|
||||
params := &stripe.CheckoutSessionParams{
|
||||
ClientReferenceID: stripe.String(referenceId),
|
||||
SuccessURL: stripe.String(setting.ServerAddress + "/log"),
|
||||
CancelURL: stripe.String(setting.ServerAddress + "/topup"),
|
||||
SuccessURL: stripe.String(system_setting.ServerAddress + "/log"),
|
||||
CancelURL: stripe.String(system_setting.ServerAddress + "/topup"),
|
||||
LineItems: []*stripe.CheckoutSessionLineItemParams{
|
||||
{
|
||||
Price: stripe.String(setting.StripePriceId),
|
||||
@@ -254,6 +256,7 @@ func GetChargedAmount(count float64, user model.User) float64 {
|
||||
}
|
||||
|
||||
func getStripePayMoney(amount float64, group string) float64 {
|
||||
originalAmount := amount
|
||||
if !common.DisplayInCurrencyEnabled {
|
||||
amount = amount / common.QuotaPerUnit
|
||||
}
|
||||
@@ -262,7 +265,14 @@ func getStripePayMoney(amount float64, group string) float64 {
|
||||
if topupGroupRatio == 0 {
|
||||
topupGroupRatio = 1
|
||||
}
|
||||
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio
|
||||
// apply optional preset discount by the original request amount (if configured), default 1.0
|
||||
discount := 1.0
|
||||
if ds, ok := operation_setting.GetPaymentSetting().AmountDiscount[int(originalAmount)]; ok {
|
||||
if ds > 0 {
|
||||
discount = ds
|
||||
}
|
||||
}
|
||||
payMoney := amount * setting.StripeUnitPrice * topupGroupRatio * discount
|
||||
return payMoney
|
||||
}
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ type Monitor struct {
|
||||
|
||||
type UptimeGroupResult struct {
|
||||
CategoryName string `json:"categoryName"`
|
||||
Monitors []Monitor `json:"monitors"`
|
||||
Monitors []Monitor `json:"monitors"`
|
||||
}
|
||||
|
||||
func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error {
|
||||
@@ -57,29 +57,29 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st
|
||||
url, _ := groupConfig["url"].(string)
|
||||
slug, _ := groupConfig["slug"].(string)
|
||||
categoryName, _ := groupConfig["categoryName"].(string)
|
||||
|
||||
|
||||
result := UptimeGroupResult{
|
||||
CategoryName: categoryName,
|
||||
Monitors: []Monitor{},
|
||||
Monitors: []Monitor{},
|
||||
}
|
||||
|
||||
|
||||
if url == "" || slug == "" {
|
||||
return result
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSuffix(url, "/")
|
||||
|
||||
|
||||
var statusData struct {
|
||||
PublicGroupList []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
MonitorList []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
} `json:"monitorList"`
|
||||
} `json:"publicGroupList"`
|
||||
}
|
||||
|
||||
|
||||
var heartbeatData struct {
|
||||
HeartbeatList map[string][]struct {
|
||||
Status int `json:"status"`
|
||||
@@ -88,11 +88,11 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st
|
||||
}
|
||||
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData)
|
||||
})
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
||||
g.Go(func() error {
|
||||
return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData)
|
||||
})
|
||||
|
||||
if g.Wait() != nil {
|
||||
@@ -139,7 +139,7 @@ func GetUptimeKumaStatus(c *gin.Context) {
|
||||
|
||||
client := &http.Client{Timeout: httpTimeout}
|
||||
results := make([]UptimeGroupResult, len(groups))
|
||||
|
||||
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
for i, group := range groups {
|
||||
i, group := i, group
|
||||
@@ -148,7 +148,7 @@ func GetUptimeKumaStatus(c *gin.Context) {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
g.Wait()
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,6 +210,7 @@ func Register(c *gin.Context) {
|
||||
Password: user.Password,
|
||||
DisplayName: user.Username,
|
||||
InviterId: inviterId,
|
||||
Role: common.RoleCommonUser, // 明确设置角色为普通用户
|
||||
}
|
||||
if common.EmailVerificationEnabled {
|
||||
cleanUser.Email = user.Email
|
||||
@@ -426,6 +427,7 @@ func GetAffCode(c *gin.Context) {
|
||||
|
||||
func GetSelf(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
userRole := c.GetInt("role")
|
||||
user, err := model.GetUserById(id, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -434,14 +436,134 @@ func GetSelf(c *gin.Context) {
|
||||
// Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users
|
||||
user.Remark = ""
|
||||
|
||||
// 计算用户权限信息
|
||||
permissions := calculateUserPermissions(userRole)
|
||||
|
||||
// 获取用户设置并提取sidebar_modules
|
||||
userSetting := user.GetSetting()
|
||||
|
||||
// 构建响应数据,包含用户信息和权限
|
||||
responseData := map[string]interface{}{
|
||||
"id": user.Id,
|
||||
"username": user.Username,
|
||||
"display_name": user.DisplayName,
|
||||
"role": user.Role,
|
||||
"status": user.Status,
|
||||
"email": user.Email,
|
||||
"group": user.Group,
|
||||
"quota": user.Quota,
|
||||
"used_quota": user.UsedQuota,
|
||||
"request_count": user.RequestCount,
|
||||
"aff_code": user.AffCode,
|
||||
"aff_count": user.AffCount,
|
||||
"aff_quota": user.AffQuota,
|
||||
"aff_history_quota": user.AffHistoryQuota,
|
||||
"inviter_id": user.InviterId,
|
||||
"linux_do_id": user.LinuxDOId,
|
||||
"setting": user.Setting,
|
||||
"stripe_customer": user.StripeCustomer,
|
||||
"sidebar_modules": userSetting.SidebarModules, // 正确提取sidebar_modules字段
|
||||
"permissions": permissions, // 新增权限字段
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": user,
|
||||
"data": responseData,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 计算用户权限的辅助函数
|
||||
func calculateUserPermissions(userRole int) map[string]interface{} {
|
||||
permissions := map[string]interface{}{}
|
||||
|
||||
// 根据用户角色计算权限
|
||||
if userRole == common.RoleRootUser {
|
||||
// 超级管理员不需要边栏设置功能
|
||||
permissions["sidebar_settings"] = false
|
||||
permissions["sidebar_modules"] = map[string]interface{}{}
|
||||
} else if userRole == common.RoleAdminUser {
|
||||
// 管理员可以设置边栏,但不包含系统设置功能
|
||||
permissions["sidebar_settings"] = true
|
||||
permissions["sidebar_modules"] = map[string]interface{}{
|
||||
"admin": map[string]interface{}{
|
||||
"setting": false, // 管理员不能访问系统设置
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// 普通用户只能设置个人功能,不包含管理员区域
|
||||
permissions["sidebar_settings"] = true
|
||||
permissions["sidebar_modules"] = map[string]interface{}{
|
||||
"admin": false, // 普通用户不能访问管理员区域
|
||||
}
|
||||
}
|
||||
|
||||
return permissions
|
||||
}
|
||||
|
||||
// 根据用户角色生成默认的边栏配置
|
||||
func generateDefaultSidebarConfig(userRole int) string {
|
||||
defaultConfig := map[string]interface{}{}
|
||||
|
||||
// 聊天区域 - 所有用户都可以访问
|
||||
defaultConfig["chat"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"playground": true,
|
||||
"chat": true,
|
||||
}
|
||||
|
||||
// 控制台区域 - 所有用户都可以访问
|
||||
defaultConfig["console"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"detail": true,
|
||||
"token": true,
|
||||
"log": true,
|
||||
"midjourney": true,
|
||||
"task": true,
|
||||
}
|
||||
|
||||
// 个人中心区域 - 所有用户都可以访问
|
||||
defaultConfig["personal"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"topup": true,
|
||||
"personal": true,
|
||||
}
|
||||
|
||||
// 管理员区域 - 根据角色决定
|
||||
if userRole == common.RoleAdminUser {
|
||||
// 管理员可以访问管理员区域,但不能访问系统设置
|
||||
defaultConfig["admin"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"channel": true,
|
||||
"models": true,
|
||||
"redemption": true,
|
||||
"user": true,
|
||||
"setting": false, // 管理员不能访问系统设置
|
||||
}
|
||||
} else if userRole == common.RoleRootUser {
|
||||
// 超级管理员可以访问所有功能
|
||||
defaultConfig["admin"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"channel": true,
|
||||
"models": true,
|
||||
"redemption": true,
|
||||
"user": true,
|
||||
"setting": true,
|
||||
}
|
||||
}
|
||||
// 普通用户不包含admin区域
|
||||
|
||||
// 转换为JSON字符串
|
||||
configBytes, err := json.Marshal(defaultConfig)
|
||||
if err != nil {
|
||||
common.SysLog("生成默认边栏配置失败: " + err.Error())
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(configBytes)
|
||||
}
|
||||
|
||||
func GetUserModels(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
@@ -528,8 +650,8 @@ func UpdateUser(c *gin.Context) {
|
||||
}
|
||||
|
||||
func UpdateSelf(c *gin.Context) {
|
||||
var user model.User
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&user)
|
||||
var requestData map[string]interface{}
|
||||
err := json.NewDecoder(c.Request.Body).Decode(&requestData)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
@@ -537,6 +659,60 @@ func UpdateSelf(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否是sidebar_modules更新请求
|
||||
if sidebarModules, exists := requestData["sidebar_modules"]; exists {
|
||||
userId := c.GetInt("id")
|
||||
user, err := model.GetUserById(userId, false)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取当前用户设置
|
||||
currentSetting := user.GetSetting()
|
||||
|
||||
// 更新sidebar_modules字段
|
||||
if sidebarModulesStr, ok := sidebarModules.(string); ok {
|
||||
currentSetting.SidebarModules = sidebarModulesStr
|
||||
}
|
||||
|
||||
// 保存更新后的设置
|
||||
user.SetSetting(currentSetting)
|
||||
if err := user.Update(false); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "更新设置失败: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "设置更新成功",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 原有的用户信息更新逻辑
|
||||
var user model.User
|
||||
requestDataBytes, err := json.Marshal(requestData)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal(requestDataBytes, &user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的参数",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if user.Password == "" {
|
||||
user.Password = "$I_LOVE_U" // make Validator happy :)
|
||||
}
|
||||
@@ -679,6 +855,7 @@ func CreateUser(c *gin.Context) {
|
||||
Username: user.Username,
|
||||
Password: user.Password,
|
||||
DisplayName: user.DisplayName,
|
||||
Role: user.Role, // 保持管理员设置的角色
|
||||
}
|
||||
if err := cleanUser.Insert(0); err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -920,6 +1097,7 @@ type UpdateUserSettingRequest struct {
|
||||
WebhookUrl string `json:"webhook_url,omitempty"`
|
||||
WebhookSecret string `json:"webhook_secret,omitempty"`
|
||||
NotificationEmail string `json:"notification_email,omitempty"`
|
||||
BarkUrl string `json:"bark_url,omitempty"`
|
||||
AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"`
|
||||
RecordIpLog bool `json:"record_ip_log"`
|
||||
}
|
||||
@@ -935,7 +1113,7 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 验证预警类型
|
||||
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook {
|
||||
if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook && req.QuotaWarningType != dto.NotifyTypeBark {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的预警类型",
|
||||
@@ -983,6 +1161,33 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 如果是Bark类型,验证Bark URL
|
||||
if req.QuotaWarningType == dto.NotifyTypeBark {
|
||||
if req.BarkUrl == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "Bark推送URL不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 验证URL格式
|
||||
if _, err := url.ParseRequestURI(req.BarkUrl); err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的Bark推送URL",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 检查是否是HTTP或HTTPS
|
||||
if !strings.HasPrefix(req.BarkUrl, "https://") && !strings.HasPrefix(req.BarkUrl, "http://") {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "Bark推送URL必须以http://或https://开头",
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userId := c.GetInt("id")
|
||||
user, err := model.GetUserById(userId, true)
|
||||
if err != nil {
|
||||
@@ -1011,6 +1216,11 @@ func UpdateUserSetting(c *gin.Context) {
|
||||
settings.NotificationEmail = req.NotificationEmail
|
||||
}
|
||||
|
||||
// 如果是Bark类型,添加Bark URL到设置中
|
||||
if req.QuotaWarningType == dto.NotifyTypeBark {
|
||||
settings.BarkUrl = req.BarkUrl
|
||||
}
|
||||
|
||||
// 更新用户设置
|
||||
user.SetSetting(settings)
|
||||
if err := user.Update(false); err != nil {
|
||||
|
||||
@@ -9,6 +9,14 @@ type ChannelSettings struct {
|
||||
SystemPromptOverride bool `json:"system_prompt_override,omitempty"`
|
||||
}
|
||||
|
||||
type VertexKeyType string
|
||||
|
||||
const (
|
||||
VertexKeyTypeJSON VertexKeyType = "json"
|
||||
VertexKeyTypeAPIKey VertexKeyType = "api_key"
|
||||
)
|
||||
|
||||
type ChannelOtherSettings struct {
|
||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||
AzureResponsesVersion string `json:"azure_responses_version,omitempty"`
|
||||
VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key"
|
||||
}
|
||||
|
||||
@@ -59,6 +59,31 @@ func (i *ImageRequest) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 序列化时需要重新把字段平铺
|
||||
func (r ImageRequest) MarshalJSON() ([]byte, error) {
|
||||
// 将已定义字段转为 map
|
||||
type Alias ImageRequest
|
||||
alias := Alias(r)
|
||||
base, err := common.Marshal(alias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var baseMap map[string]json.RawMessage
|
||||
if err := common.Unmarshal(base, &baseMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 合并 ExtraFields
|
||||
for k, v := range r.Extra {
|
||||
if _, exists := baseMap[k]; !exists {
|
||||
baseMap[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(baseMap)
|
||||
}
|
||||
|
||||
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
|
||||
fields := make(map[string]struct{})
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
|
||||
@@ -1,23 +1,23 @@
|
||||
package dto
|
||||
|
||||
type UpstreamDTO struct {
|
||||
ID int `json:"id,omitempty"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
ID int `json:"id,omitempty"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
type UpstreamRequest struct {
|
||||
ChannelIDs []int64 `json:"channel_ids"`
|
||||
Upstreams []UpstreamDTO `json:"upstreams"`
|
||||
Timeout int `json:"timeout"`
|
||||
ChannelIDs []int64 `json:"channel_ids"`
|
||||
Upstreams []UpstreamDTO `json:"upstreams"`
|
||||
Timeout int `json:"timeout"`
|
||||
}
|
||||
|
||||
// TestResult 上游测试连通性结果
|
||||
type TestResult struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// DifferenceItem 差异项
|
||||
@@ -25,14 +25,14 @@ type TestResult struct {
|
||||
// Upstreams 为各渠道的上游值,具体数值 / "same" / nil
|
||||
|
||||
type DifferenceItem struct {
|
||||
Current interface{} `json:"current"`
|
||||
Upstreams map[string]interface{} `json:"upstreams"`
|
||||
Confidence map[string]bool `json:"confidence"`
|
||||
Current interface{} `json:"current"`
|
||||
Upstreams map[string]interface{} `json:"upstreams"`
|
||||
Confidence map[string]bool `json:"confidence"`
|
||||
}
|
||||
|
||||
type SyncableChannel struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
BaseURL string `json:"base_url"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
@@ -6,11 +6,14 @@ type UserSetting struct {
|
||||
WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址
|
||||
WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥
|
||||
NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址
|
||||
BarkUrl string `json:"bark_url,omitempty"` // BarkUrl Bark推送URL
|
||||
AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型
|
||||
RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP
|
||||
SidebarModules string `json:"sidebar_modules,omitempty"` // SidebarModules 左侧边栏模块配置
|
||||
}
|
||||
|
||||
var (
|
||||
NotifyTypeEmail = "email" // Email 邮件
|
||||
NotifyTypeWebhook = "webhook" // Webhook
|
||||
NotifyTypeBark = "bark" // Bark 推送
|
||||
)
|
||||
|
||||
12
main.go
12
main.go
@@ -94,13 +94,9 @@ func main() {
|
||||
}
|
||||
go controller.AutomaticallyUpdateChannels(frequency)
|
||||
}
|
||||
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
|
||||
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
|
||||
if err != nil {
|
||||
common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
|
||||
}
|
||||
go controller.AutomaticallyTestChannels(frequency)
|
||||
}
|
||||
|
||||
go controller.AutomaticallyTestChannels()
|
||||
|
||||
if common.IsMasterNode && constant.UpdateTask {
|
||||
gopool.Go(func() {
|
||||
controller.UpdateMidjourneyTaskBulk()
|
||||
@@ -208,4 +204,4 @@ func InitResources() error {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -166,9 +166,9 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
relayMode := relayconstant.RelayModeUnknown
|
||||
if c.Request.Method == http.MethodPost {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
relayMode = relayconstant.RelayModeVideoSubmit
|
||||
} else if c.Request.Method == http.MethodGet {
|
||||
relayMode = relayconstant.RelayModeVideoFetchByID
|
||||
|
||||
@@ -18,12 +18,12 @@ func StatsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 增加活跃连接数
|
||||
atomic.AddInt64(&globalStats.activeConnections, 1)
|
||||
|
||||
|
||||
// 确保在请求结束时减少连接数
|
||||
defer func() {
|
||||
atomic.AddInt64(&globalStats.activeConnections, -1)
|
||||
}()
|
||||
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -38,4 +38,4 @@ func GetStats() StatsInfo {
|
||||
return StatsInfo{
|
||||
ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,14 +42,16 @@ type Channel struct {
|
||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||
OtherInfo string `json:"other_info"`
|
||||
OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置
|
||||
Tag *string `json:"tag" gorm:"index"`
|
||||
Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置
|
||||
ParamOverride *string `json:"param_override" gorm:"type:text"`
|
||||
HeaderOverride *string `json:"header_override" gorm:"type:text"`
|
||||
Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"`
|
||||
// add after v0.8.5
|
||||
ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"`
|
||||
|
||||
OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置,存储azure版本等不需要检索的信息,详见dto.ChannelOtherSettings
|
||||
|
||||
// cache info
|
||||
Keys []string `json:"-" gorm:"-"`
|
||||
}
|
||||
@@ -606,8 +608,12 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
return false
|
||||
}
|
||||
if channelCache.ChannelInfo.IsMultiKey {
|
||||
// Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey
|
||||
pollingLock := GetChannelPollingLock(channelId)
|
||||
pollingLock.Lock()
|
||||
// 如果是多Key模式,更新缓存中的状态
|
||||
handlerMultiKeyUpdate(channelCache, usingKey, status, reason)
|
||||
pollingLock.Unlock()
|
||||
//CacheUpdateChannel(channelCache)
|
||||
//return true
|
||||
} else {
|
||||
@@ -638,7 +644,11 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri
|
||||
|
||||
if channel.ChannelInfo.IsMultiKey {
|
||||
beforeStatus := channel.Status
|
||||
// Protect map writes with the same per-channel lock used by readers
|
||||
pollingLock := GetChannelPollingLock(channelId)
|
||||
pollingLock.Lock()
|
||||
handlerMultiKeyUpdate(channel, usingKey, status, reason)
|
||||
pollingLock.Unlock()
|
||||
if beforeStatus != channel.Status {
|
||||
shouldUpdateAbilities = true
|
||||
}
|
||||
|
||||
@@ -64,22 +64,6 @@ var DB *gorm.DB
|
||||
|
||||
var LOG_DB *gorm.DB
|
||||
|
||||
// dropIndexIfExists drops a MySQL index only if it exists to avoid noisy 1091 errors
|
||||
func dropIndexIfExists(tableName string, indexName string) {
|
||||
if !common.UsingMySQL {
|
||||
return
|
||||
}
|
||||
var count int64
|
||||
// Check index existence via information_schema
|
||||
err := DB.Raw(
|
||||
"SELECT COUNT(1) FROM information_schema.statistics WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?",
|
||||
tableName, indexName,
|
||||
).Scan(&count).Error
|
||||
if err == nil && count > 0 {
|
||||
_ = DB.Exec("ALTER TABLE " + tableName + " DROP INDEX " + indexName + ";").Error
|
||||
}
|
||||
}
|
||||
|
||||
func createRootAccountIfNeed() error {
|
||||
var user User
|
||||
//if user.Status != common.UserStatusEnabled {
|
||||
@@ -263,16 +247,6 @@ func InitLogDB() (err error) {
|
||||
}
|
||||
|
||||
func migrateDB() error {
|
||||
// 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录
|
||||
// 删除单列唯一索引(列级 UNIQUE)及早期命名方式,防止与新复合唯一索引 (model_name, deleted_at) 冲突
|
||||
dropIndexIfExists("models", "uk_model_name") // 新版复合索引名称(若已存在)
|
||||
dropIndexIfExists("models", "model_name") // 旧版列级唯一索引名称
|
||||
|
||||
dropIndexIfExists("vendors", "uk_vendor_name") // 新版复合索引名称(若已存在)
|
||||
dropIndexIfExists("vendors", "name") // 旧版列级唯一索引名称
|
||||
//if !common.UsingPostgreSQL {
|
||||
// return migrateDBFast()
|
||||
//}
|
||||
err := DB.AutoMigrate(
|
||||
&Channel{},
|
||||
&Token{},
|
||||
@@ -299,13 +273,6 @@ func migrateDB() error {
|
||||
}
|
||||
|
||||
func migrateDBFast() error {
|
||||
// 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录
|
||||
// 删除单列唯一索引(列级 UNIQUE)及早期命名方式,防止与新复合唯一索引冲突
|
||||
dropIndexIfExists("models", "uk_model_name")
|
||||
dropIndexIfExists("models", "model_name")
|
||||
|
||||
dropIndexIfExists("vendors", "uk_vendor_name")
|
||||
dropIndexIfExists("vendors", "name")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
|
||||
@@ -20,17 +20,18 @@ type BoundChannel struct {
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
Id int `json:"id"`
|
||||
ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name,priority:1"`
|
||||
Description string `json:"description,omitempty" gorm:"type:text"`
|
||||
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
||||
Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
|
||||
VendorID int `json:"vendor_id,omitempty" gorm:"index"`
|
||||
Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name,priority:2"`
|
||||
Id int `json:"id"`
|
||||
ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name_delete_at,priority:1"`
|
||||
Description string `json:"description,omitempty" gorm:"type:text"`
|
||||
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
||||
Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
|
||||
VendorID int `json:"vendor_id,omitempty" gorm:"index"`
|
||||
Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
SyncOfficial int `json:"sync_official" gorm:"default:1"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name_delete_at,priority:2"`
|
||||
|
||||
BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
|
||||
EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"one-api/setting/config"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -66,16 +67,16 @@ func InitOptionMap() {
|
||||
common.OptionMap["SystemName"] = common.SystemName
|
||||
common.OptionMap["Logo"] = common.Logo
|
||||
common.OptionMap["ServerAddress"] = ""
|
||||
common.OptionMap["WorkerUrl"] = setting.WorkerUrl
|
||||
common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
|
||||
common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(setting.WorkerAllowHttpImageRequestEnabled)
|
||||
common.OptionMap["WorkerUrl"] = system_setting.WorkerUrl
|
||||
common.OptionMap["WorkerValidKey"] = system_setting.WorkerValidKey
|
||||
common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(system_setting.WorkerAllowHttpImageRequestEnabled)
|
||||
common.OptionMap["PayAddress"] = ""
|
||||
common.OptionMap["CustomCallbackAddress"] = ""
|
||||
common.OptionMap["EpayId"] = ""
|
||||
common.OptionMap["EpayKey"] = ""
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
|
||||
common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(setting.USDExchangeRate, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
|
||||
common.OptionMap["Price"] = strconv.FormatFloat(operation_setting.Price, 'f', -1, 64)
|
||||
common.OptionMap["USDExchangeRate"] = strconv.FormatFloat(operation_setting.USDExchangeRate, 'f', -1, 64)
|
||||
common.OptionMap["MinTopUp"] = strconv.Itoa(operation_setting.MinTopUp)
|
||||
common.OptionMap["StripeMinTopUp"] = strconv.Itoa(setting.StripeMinTopUp)
|
||||
common.OptionMap["StripeApiSecret"] = setting.StripeApiSecret
|
||||
common.OptionMap["StripeWebhookSecret"] = setting.StripeWebhookSecret
|
||||
@@ -85,7 +86,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["Chats"] = setting.Chats2JsonString()
|
||||
common.OptionMap["AutoGroups"] = setting.AutoGroups2JsonString()
|
||||
common.OptionMap["DefaultUseAutoGroup"] = strconv.FormatBool(setting.DefaultUseAutoGroup)
|
||||
common.OptionMap["PayMethods"] = setting.PayMethods2JsonString()
|
||||
common.OptionMap["PayMethods"] = operation_setting.PayMethods2JsonString()
|
||||
common.OptionMap["GitHubClientId"] = ""
|
||||
common.OptionMap["GitHubClientSecret"] = ""
|
||||
common.OptionMap["TelegramBotToken"] = ""
|
||||
@@ -274,7 +275,7 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "SMTPSSLEnabled":
|
||||
common.SMTPSSLEnabled = boolValue
|
||||
case "WorkerAllowHttpImageRequestEnabled":
|
||||
setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
||||
system_setting.WorkerAllowHttpImageRequestEnabled = boolValue
|
||||
case "DefaultUseAutoGroup":
|
||||
setting.DefaultUseAutoGroup = boolValue
|
||||
case "ExposeRatioEnabled":
|
||||
@@ -296,29 +297,29 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "SMTPToken":
|
||||
common.SMTPToken = value
|
||||
case "ServerAddress":
|
||||
setting.ServerAddress = value
|
||||
system_setting.ServerAddress = value
|
||||
case "WorkerUrl":
|
||||
setting.WorkerUrl = value
|
||||
system_setting.WorkerUrl = value
|
||||
case "WorkerValidKey":
|
||||
setting.WorkerValidKey = value
|
||||
system_setting.WorkerValidKey = value
|
||||
case "PayAddress":
|
||||
setting.PayAddress = value
|
||||
operation_setting.PayAddress = value
|
||||
case "Chats":
|
||||
err = setting.UpdateChatsByJsonString(value)
|
||||
case "AutoGroups":
|
||||
err = setting.UpdateAutoGroupsByJsonString(value)
|
||||
case "CustomCallbackAddress":
|
||||
setting.CustomCallbackAddress = value
|
||||
operation_setting.CustomCallbackAddress = value
|
||||
case "EpayId":
|
||||
setting.EpayId = value
|
||||
operation_setting.EpayId = value
|
||||
case "EpayKey":
|
||||
setting.EpayKey = value
|
||||
operation_setting.EpayKey = value
|
||||
case "Price":
|
||||
setting.Price, _ = strconv.ParseFloat(value, 64)
|
||||
operation_setting.Price, _ = strconv.ParseFloat(value, 64)
|
||||
case "USDExchangeRate":
|
||||
setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
|
||||
operation_setting.USDExchangeRate, _ = strconv.ParseFloat(value, 64)
|
||||
case "MinTopUp":
|
||||
setting.MinTopUp, _ = strconv.Atoi(value)
|
||||
operation_setting.MinTopUp, _ = strconv.Atoi(value)
|
||||
case "StripeApiSecret":
|
||||
setting.StripeApiSecret = value
|
||||
case "StripeWebhookSecret":
|
||||
@@ -422,7 +423,7 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
case "StreamCacheQueueLength":
|
||||
setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
|
||||
case "PayMethods":
|
||||
err = setting.UpdatePayMethodsByJsonString(value)
|
||||
err = operation_setting.UpdatePayMethodsByJsonString(value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type TwoFA struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
UserId int `json:"user_id" gorm:"unique;not null;index"`
|
||||
Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥,不返回给前端
|
||||
IsEnabled bool `json:"is_enabled" gorm:"default:false"`
|
||||
IsEnabled bool `json:"is_enabled"`
|
||||
FailedAttempts int `json:"failed_attempts" gorm:"default:0"`
|
||||
LockedUntil *time.Time `json:"locked_until,omitempty"`
|
||||
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||
@@ -30,7 +30,7 @@ type TwoFABackupCode struct {
|
||||
Id int `json:"id" gorm:"primaryKey"`
|
||||
UserId int `json:"user_id" gorm:"not null;index"`
|
||||
CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希
|
||||
IsUsed bool `json:"is_used" gorm:"default:false"`
|
||||
IsUsed bool `json:"is_used"`
|
||||
UsedAt *time.Time `json:"used_at,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index"`
|
||||
|
||||
@@ -91,6 +91,68 @@ func (user *User) SetSetting(setting dto.UserSetting) {
|
||||
user.Setting = string(settingBytes)
|
||||
}
|
||||
|
||||
// 根据用户角色生成默认的边栏配置
|
||||
func generateDefaultSidebarConfigForRole(userRole int) string {
|
||||
defaultConfig := map[string]interface{}{}
|
||||
|
||||
// 聊天区域 - 所有用户都可以访问
|
||||
defaultConfig["chat"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"playground": true,
|
||||
"chat": true,
|
||||
}
|
||||
|
||||
// 控制台区域 - 所有用户都可以访问
|
||||
defaultConfig["console"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"detail": true,
|
||||
"token": true,
|
||||
"log": true,
|
||||
"midjourney": true,
|
||||
"task": true,
|
||||
}
|
||||
|
||||
// 个人中心区域 - 所有用户都可以访问
|
||||
defaultConfig["personal"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"topup": true,
|
||||
"personal": true,
|
||||
}
|
||||
|
||||
// 管理员区域 - 根据角色决定
|
||||
if userRole == common.RoleAdminUser {
|
||||
// 管理员可以访问管理员区域,但不能访问系统设置
|
||||
defaultConfig["admin"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"channel": true,
|
||||
"models": true,
|
||||
"redemption": true,
|
||||
"user": true,
|
||||
"setting": false, // 管理员不能访问系统设置
|
||||
}
|
||||
} else if userRole == common.RoleRootUser {
|
||||
// 超级管理员可以访问所有功能
|
||||
defaultConfig["admin"] = map[string]interface{}{
|
||||
"enabled": true,
|
||||
"channel": true,
|
||||
"models": true,
|
||||
"redemption": true,
|
||||
"user": true,
|
||||
"setting": true,
|
||||
}
|
||||
}
|
||||
// 普通用户不包含admin区域
|
||||
|
||||
// 转换为JSON字符串
|
||||
configBytes, err := json.Marshal(defaultConfig)
|
||||
if err != nil {
|
||||
common.SysLog("生成默认边栏配置失败: " + err.Error())
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(configBytes)
|
||||
}
|
||||
|
||||
// CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil
|
||||
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
|
||||
var user User
|
||||
@@ -320,10 +382,34 @@ func (user *User) Insert(inviterId int) error {
|
||||
user.Quota = common.QuotaForNewUser
|
||||
//user.SetAccessToken(common.GetUUID())
|
||||
user.AffCode = common.GetRandomString(4)
|
||||
|
||||
// 初始化用户设置,包括默认的边栏配置
|
||||
if user.Setting == "" {
|
||||
defaultSetting := dto.UserSetting{}
|
||||
// 这里暂时不设置SidebarModules,因为需要在用户创建后根据角色设置
|
||||
user.SetSetting(defaultSetting)
|
||||
}
|
||||
|
||||
result := DB.Create(user)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// 用户创建成功后,根据角色初始化边栏配置
|
||||
// 需要重新获取用户以确保有正确的ID和Role
|
||||
var createdUser User
|
||||
if err := DB.Where("username = ?", user.Username).First(&createdUser).Error; err == nil {
|
||||
// 生成基于角色的默认边栏配置
|
||||
defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role)
|
||||
if defaultSidebarConfig != "" {
|
||||
currentSetting := createdUser.GetSetting()
|
||||
currentSetting.SidebarModules = defaultSidebarConfig
|
||||
createdUser.SetSetting(currentSetting)
|
||||
createdUser.Update(false)
|
||||
common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role))
|
||||
}
|
||||
}
|
||||
|
||||
if common.QuotaForNewUser > 0 {
|
||||
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser)))
|
||||
}
|
||||
|
||||
@@ -14,13 +14,13 @@ import (
|
||||
|
||||
type Vendor struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name,priority:1"`
|
||||
Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name_delete_at,priority:1"`
|
||||
Description string `json:"description,omitempty" gorm:"type:text"`
|
||||
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
||||
Status int `json:"status" gorm:"default:1"`
|
||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name,priority:2"`
|
||||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name_delete_at,priority:2"`
|
||||
}
|
||||
|
||||
// Insert 创建新的供应商记录
|
||||
|
||||
@@ -53,7 +53,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -264,9 +264,8 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, errors.New("resp is nil")
|
||||
|
||||
@@ -60,7 +60,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
// 检查是否为Nova模型
|
||||
if isNovaModel(request.Model) {
|
||||
novaReq := convertToNovaRequest(request)
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", novaReq)
|
||||
c.Set("is_nova_model", true)
|
||||
return novaReq, nil
|
||||
}
|
||||
|
||||
// 原有的Claude模型处理逻辑
|
||||
var claudeReq *dto.ClaudeRequest
|
||||
var err error
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||
@@ -69,6 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
c.Set("request_model", claudeReq.Model)
|
||||
c.Set("converted_request", claudeReq)
|
||||
c.Set("is_nova_model", false)
|
||||
return claudeReq, err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package aws
|
||||
|
||||
import "strings"
|
||||
|
||||
var awsModelIDMap = map[string]string{
|
||||
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
||||
"claude-2.0": "anthropic.claude-v2",
|
||||
@@ -14,6 +16,11 @@ var awsModelIDMap = map[string]string{
|
||||
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
|
||||
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
// Nova models
|
||||
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
|
||||
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
|
||||
"nova-pro-v1:0": "amazon.nova-pro-v1:0",
|
||||
"nova-premier-v1:0": "amazon.nova-premier-v1:0",
|
||||
}
|
||||
|
||||
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
@@ -58,7 +65,27 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
}
|
||||
// Nova models - all support three major regions
|
||||
"amazon.nova-micro-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-lite-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-pro-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-premier-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
}}
|
||||
|
||||
var awsRegionCrossModelPrefixMap = map[string]string{
|
||||
"us": "us",
|
||||
@@ -67,3 +94,8 @@ var awsRegionCrossModelPrefixMap = map[string]string{
|
||||
}
|
||||
|
||||
var ChannelName = "aws"
|
||||
|
||||
// 判断是否为Nova模型
|
||||
func isNovaModel(modelId string) bool {
|
||||
return strings.HasPrefix(modelId, "nova-")
|
||||
}
|
||||
|
||||
@@ -34,3 +34,92 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
|
||||
Thinking: req.Thinking,
|
||||
}
|
||||
}
|
||||
|
||||
// NovaMessage Nova模型使用messages-v1格式
|
||||
type NovaMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []NovaContent `json:"content"`
|
||||
}
|
||||
|
||||
type NovaContent struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type NovaRequest struct {
|
||||
SchemaVersion string `json:"schemaVersion"` // 请求版本,例如 "1.0"
|
||||
Messages []NovaMessage `json:"messages"` // 对话消息列表
|
||||
InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"` // 推理配置,可选
|
||||
}
|
||||
|
||||
type NovaInferenceConfig struct {
|
||||
MaxTokens int `json:"maxTokens,omitempty"` // 最大生成的 token 数
|
||||
Temperature float64 `json:"temperature,omitempty"` // 随机性 (默认 0.7, 范围 0-1)
|
||||
TopP float64 `json:"topP,omitempty"` // nucleus sampling (默认 0.9, 范围 0-1)
|
||||
TopK int `json:"topK,omitempty"` // 限制候选 token 数 (默认 50, 范围 0-128)
|
||||
StopSequences []string `json:"stopSequences,omitempty"` // 停止生成的序列
|
||||
}
|
||||
|
||||
// 转换OpenAI请求为Nova格式
|
||||
func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
|
||||
novaMessages := make([]NovaMessage, len(req.Messages))
|
||||
for i, msg := range req.Messages {
|
||||
novaMessages[i] = NovaMessage{
|
||||
Role: msg.Role,
|
||||
Content: []NovaContent{{Text: msg.StringContent()}},
|
||||
}
|
||||
}
|
||||
|
||||
novaReq := &NovaRequest{
|
||||
SchemaVersion: "messages-v1",
|
||||
Messages: novaMessages,
|
||||
}
|
||||
|
||||
// 设置推理配置
|
||||
if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 || req.TopK != 0 || req.Stop != nil {
|
||||
novaReq.InferenceConfig = &NovaInferenceConfig{}
|
||||
if req.MaxTokens != 0 {
|
||||
novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens)
|
||||
}
|
||||
if req.Temperature != nil && *req.Temperature != 0 {
|
||||
novaReq.InferenceConfig.Temperature = *req.Temperature
|
||||
}
|
||||
if req.TopP != 0 {
|
||||
novaReq.InferenceConfig.TopP = req.TopP
|
||||
}
|
||||
if req.TopK != 0 {
|
||||
novaReq.InferenceConfig.TopK = req.TopK
|
||||
}
|
||||
if req.Stop != nil {
|
||||
if stopSequences := parseStopSequences(req.Stop); len(stopSequences) > 0 {
|
||||
novaReq.InferenceConfig.StopSequences = stopSequences
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return novaReq
|
||||
}
|
||||
|
||||
// parseStopSequences 解析停止序列,支持字符串或字符串数组
|
||||
func parseStopSequences(stop any) []string {
|
||||
if stop == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := stop.(type) {
|
||||
case string:
|
||||
if v != "" {
|
||||
return []string{v}
|
||||
}
|
||||
case []string:
|
||||
return v
|
||||
case []interface{}:
|
||||
var sequences []string
|
||||
for _, item := range v {
|
||||
if str, ok := item.(string); ok && str != "" {
|
||||
sequences = append(sequences, str)
|
||||
}
|
||||
}
|
||||
return sequences
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -93,7 +94,19 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
}
|
||||
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
// 检查是否为Nova模型
|
||||
isNova, _ := c.Get("is_nova_model")
|
||||
if isNova == true {
|
||||
// Nova模型也支持跨区域
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
return handleNovaRequest(c, awsCli, info, awsModelId)
|
||||
}
|
||||
|
||||
// 原有的Claude处理逻辑
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
@@ -130,7 +143,12 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
Usage: &dto.Usage{},
|
||||
}
|
||||
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
|
||||
// 复制上游 Content-Type 到客户端响应头
|
||||
if awsResp.ContentType != nil && *awsResp.ContentType != "" {
|
||||
c.Writer.Header().Set("Content-Type", *awsResp.ContentType)
|
||||
}
|
||||
|
||||
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, nil, awsResp.Body, RequestModeMessage)
|
||||
if handlerErr != nil {
|
||||
return handlerErr, nil
|
||||
}
|
||||
@@ -204,3 +222,74 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
// Nova模型处理函数
|
||||
func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
|
||||
novaReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
novaReq := novaReq_.(*NovaRequest)
|
||||
|
||||
// 使用InvokeModel API,但使用Nova格式的请求体
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(novaReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
awsReq.Body = reqBody
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
// 解析Nova响应
|
||||
var novaResp struct {
|
||||
Output struct {
|
||||
Message struct {
|
||||
Content []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"output"`
|
||||
Usage struct {
|
||||
InputTokens int `json:"inputTokens"`
|
||||
OutputTokens int `json:"outputTokens"`
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil {
|
||||
return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
// 构造OpenAI格式响应
|
||||
response := dto.OpenAITextResponse{
|
||||
Id: helper.GetResponseID(c),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
Choices: []dto.OpenAITextResponseChoice{{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: novaResp.Output.Message.Content[0].Text,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}},
|
||||
Usage: dto.Usage{
|
||||
PromptTokens: novaResp.Usage.InputTokens,
|
||||
CompletionTokens: novaResp.Usage.OutputTokens,
|
||||
TotalTokens: novaResp.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
return nil, &response.Usage
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ func stopReasonClaude2OpenAI(reason string) string {
|
||||
case "end_turn":
|
||||
return "stop"
|
||||
case "max_tokens":
|
||||
return "max_tokens"
|
||||
return "length"
|
||||
case "tool_use":
|
||||
return "tool_calls"
|
||||
default:
|
||||
@@ -274,19 +274,28 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
|
||||
claudeMessages := make([]dto.ClaudeMessage, 0)
|
||||
isFirstMessage := true
|
||||
// 初始化system消息数组,用于累积多个system消息
|
||||
var systemMessages []dto.ClaudeMediaMessage
|
||||
|
||||
for _, message := range formatMessages {
|
||||
if message.Role == "system" {
|
||||
// 根据Claude API规范,system字段使用数组格式更有通用性
|
||||
if message.IsStringContent() {
|
||||
claudeRequest.System = message.StringContent()
|
||||
systemMessages = append(systemMessages, dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](message.StringContent()),
|
||||
})
|
||||
} else {
|
||||
contents := message.ParseContent()
|
||||
content := ""
|
||||
for _, ctx := range contents {
|
||||
// 支持复合内容的system消息(虽然不常见,但需要考虑完整性)
|
||||
for _, ctx := range message.ParseContent() {
|
||||
if ctx.Type == "text" {
|
||||
content += ctx.Text
|
||||
systemMessages = append(systemMessages, dto.ClaudeMediaMessage{
|
||||
Type: "text",
|
||||
Text: common.GetPointer[string](ctx.Text),
|
||||
})
|
||||
}
|
||||
// 未来可以在这里扩展对图片等其他类型的支持
|
||||
}
|
||||
claudeRequest.System = content
|
||||
}
|
||||
} else {
|
||||
if isFirstMessage {
|
||||
@@ -392,6 +401,12 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe
|
||||
claudeMessages = append(claudeMessages, claudeMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// 设置累积的system消息
|
||||
if len(systemMessages) > 0 {
|
||||
claudeRequest.System = systemMessages
|
||||
}
|
||||
|
||||
claudeRequest.Prompt = ""
|
||||
claudeRequest.Messages = claudeMessages
|
||||
return &claudeRequest, nil
|
||||
@@ -426,7 +441,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse
|
||||
choice.Delta.Role = "assistant"
|
||||
} else if claudeResponse.Type == "content_block_start" {
|
||||
if claudeResponse.ContentBlock != nil {
|
||||
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
|
||||
// 如果是文本块,尽可能发送首段文本(若存在)
|
||||
if claudeResponse.ContentBlock.Type == "text" && claudeResponse.ContentBlock.Text != nil {
|
||||
choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text)
|
||||
}
|
||||
if claudeResponse.ContentBlock.Type == "tool_use" {
|
||||
tools = append(tools, dto.ToolCallResponse{
|
||||
Index: common.GetPointer(fcIdx),
|
||||
@@ -698,7 +716,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
||||
return claudeInfo.Usage, nil
|
||||
}
|
||||
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError {
|
||||
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, httpResp *http.Response, data []byte, requestMode int) *types.NewAPIError {
|
||||
var claudeResponse dto.ClaudeResponse
|
||||
err := common.Unmarshal(data, &claudeResponse)
|
||||
if err != nil {
|
||||
@@ -736,7 +754,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
||||
c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
|
||||
}
|
||||
|
||||
service.IOCopyBytesGracefully(c, nil, responseData)
|
||||
service.IOCopyBytesGracefully(c, httpResp, responseData)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -757,7 +775,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
||||
if common.DebugEnabled {
|
||||
println("responseBody: ", string(responseBody))
|
||||
}
|
||||
handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
|
||||
handleErr := HandleClaudeResponseData(c, info, claudeInfo, resp, responseBody, requestMode)
|
||||
if handleErr != nil {
|
||||
return nil, handleErr
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference?hl=zh-cn#blob
|
||||
var geminiSupportedMimeTypes = map[string]bool{
|
||||
"application/pdf": true,
|
||||
"audio/mpeg": true,
|
||||
@@ -30,6 +31,7 @@ var geminiSupportedMimeTypes = map[string]bool{
|
||||
"audio/wav": true,
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
"image/webp": true,
|
||||
"text/plain": true,
|
||||
"video/mov": true,
|
||||
"video/mpeg": true,
|
||||
|
||||
@@ -6,4 +6,4 @@ var ModelList = []string{
|
||||
"m3e-small",
|
||||
}
|
||||
|
||||
var ChannelName = "mokaai"
|
||||
var ChannelName = "mokaai"
|
||||
|
||||
@@ -2,6 +2,7 @@ package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
@@ -280,11 +281,6 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
|
||||
defer service.CloseResponseBodyGracefully(resp)
|
||||
|
||||
// count tokens by audio file duration
|
||||
audioTokens, err := countAudioTokens(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
@@ -292,6 +288,26 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
// 写入新的 response body
|
||||
service.IOCopyBytesGracefully(c, resp, responseBody)
|
||||
|
||||
var responseData struct {
|
||||
Usage *dto.Usage `json:"usage"`
|
||||
}
|
||||
if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
|
||||
if responseData.Usage.TotalTokens > 0 {
|
||||
usage := responseData.Usage
|
||||
if usage.PromptTokens == 0 {
|
||||
usage.PromptTokens = usage.InputTokens
|
||||
}
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = usage.OutputTokens
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
}
|
||||
|
||||
audioTokens, err := countAudioTokens(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
|
||||
}
|
||||
usage := &dto.Usage{}
|
||||
usage.PromptTokens = audioTokens
|
||||
usage.CompletionTokens = 0
|
||||
|
||||
@@ -46,9 +46,17 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
|
||||
usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
if info == nil || info.ResponsesUsageInfo == nil || info.ResponsesUsageInfo.BuiltInTools == nil {
|
||||
return &usage, nil
|
||||
}
|
||||
// 解析 Tools 用量
|
||||
for _, tool := range responsesResponse.Tools {
|
||||
info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])].CallCount++
|
||||
buildToolinfo, ok := info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])]
|
||||
if !ok || buildToolinfo == nil {
|
||||
logger.LogError(c, fmt.Sprintf("BuiltInTools not found for tool type: %v", tool["type"]))
|
||||
continue
|
||||
}
|
||||
buildToolinfo.CallCount++
|
||||
}
|
||||
return &usage, nil
|
||||
}
|
||||
@@ -72,7 +80,7 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
|
||||
sendResponsesStreamData(c, streamResponse, data)
|
||||
switch streamResponse.Type {
|
||||
case "response.completed":
|
||||
if streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response != nil && streamResponse.Response.Usage != nil {
|
||||
if streamResponse.Response.Usage.InputTokens != 0 {
|
||||
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
@@ -89,22 +88,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
|
||||
req := relaycommon.TaskSubmitReq{}
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store into context for later usage
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
@@ -334,11 +318,11 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
}
|
||||
|
||||
// Handle one-of image_urls or binary_data_base64
|
||||
if req.Image != "" {
|
||||
if strings.HasPrefix(req.Image, "http") {
|
||||
r.ImageUrls = []string{req.Image}
|
||||
if req.HasImage() {
|
||||
if strings.HasPrefix(req.Images[0], "http") {
|
||||
r.ImageUrls = req.Images
|
||||
} else {
|
||||
r.BinaryDataBase64 = []string{req.Image}
|
||||
r.BinaryDataBase64 = req.Images
|
||||
}
|
||||
}
|
||||
metadata := req.Metadata
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
@@ -28,16 +27,6 @@ import (
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type SubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type TrajectoryPoint struct {
|
||||
X int `json:"x"`
|
||||
Y int `json:"y"`
|
||||
@@ -121,23 +110,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
|
||||
var req SubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store into context for later usage
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
// Use the standard validation method for TaskSubmitReq
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
@@ -166,7 +140,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(SubmitReq)
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
@@ -255,7 +229,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
Prompt: req.Prompt,
|
||||
Image: req.Image,
|
||||
|
||||
355
relay/channel/task/vertex/adaptor.go
Normal file
355
relay/channel/task/vertex/adaptor.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
vertexcore "one-api/relay/channel/vertex"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type requestPayload struct {
|
||||
Instances []map[string]any `json:"instances"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type submitResponse struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type operationVideo struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
Encoding string `json:"encoding"`
|
||||
}
|
||||
|
||||
type operationResponse struct {
|
||||
Name string `json:"name"`
|
||||
Done bool `json:"done"`
|
||||
Response struct {
|
||||
Type string `json:"@type"`
|
||||
RaiMediaFilteredCount int `json:"raiMediaFilteredCount"`
|
||||
Videos []operationVideo `json:"videos"`
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
Encoding string `json:"encoding"`
|
||||
Video string `json:"video"`
|
||||
} `json:"response"`
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.ChannelBaseUrl
|
||||
a.apiKey = info.ApiKey
|
||||
}
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Use the standard validation method for TaskSubmitReq
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate)
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
modelName := info.OriginModelName
|
||||
if modelName == "" {
|
||||
modelName = "veo-3.0-generate-001"
|
||||
}
|
||||
|
||||
region := vertexcore.GetModelRegion(info.ApiVersion, modelName)
|
||||
if strings.TrimSpace(region) == "" {
|
||||
region = "global"
|
||||
}
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning",
|
||||
adc.ProjectID,
|
||||
modelName,
|
||||
), nil
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
modelName,
|
||||
), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to acquire access token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("x-goog-user-project", adc.ProjectID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Vertex specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, ok := c.Get("task_request")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body := requestPayload{
|
||||
Instances: []map[string]any{{"prompt": req.Prompt}},
|
||||
Parameters: map[string]any{},
|
||||
}
|
||||
if req.Metadata != nil {
|
||||
if v, ok := req.Metadata["storageUri"]; ok {
|
||||
body.Parameters["storageUri"] = v
|
||||
}
|
||||
if v, ok := req.Metadata["sampleCount"]; ok {
|
||||
body.Parameters["sampleCount"] = v
|
||||
}
|
||||
}
|
||||
if _, ok := body.Parameters["sampleCount"]; !ok {
|
||||
body.Parameters["sampleCount"] = 1
|
||||
}
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var s submitResponse
|
||||
if err := json.Unmarshal(responseBody, &s); err != nil {
|
||||
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if strings.TrimSpace(s.Name) == "" {
|
||||
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
|
||||
}
|
||||
localID := encodeLocalTaskID(s.Name)
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": localID})
|
||||
return localID, responseBody, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} }
|
||||
func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
upstreamName, err := decodeLocalTaskID(taskID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode task_id failed: %w", err)
|
||||
}
|
||||
region := extractRegionFromOperationName(upstreamName)
|
||||
if region == "" {
|
||||
region = "us-central1"
|
||||
}
|
||||
project := extractProjectFromOperationName(upstreamName)
|
||||
modelName := extractModelFromOperationName(upstreamName)
|
||||
if project == "" || modelName == "" {
|
||||
return nil, fmt.Errorf("cannot extract project/model from operation name")
|
||||
}
|
||||
var url string
|
||||
if region == "global" {
|
||||
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName)
|
||||
} else {
|
||||
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
|
||||
}
|
||||
payload := map[string]string{"operationName": upstreamName}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(key), adc); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to acquire access token: %w", err)
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("x-goog-user-project", adc.ProjectID)
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
var op operationResponse
|
||||
if err := json.Unmarshal(respBody, &op); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
|
||||
}
|
||||
ti := &relaycommon.TaskInfo{}
|
||||
if op.Error.Message != "" {
|
||||
ti.Status = model.TaskStatusFailure
|
||||
ti.Reason = op.Error.Message
|
||||
ti.Progress = "100%"
|
||||
return ti, nil
|
||||
}
|
||||
if !op.Done {
|
||||
ti.Status = model.TaskStatusInProgress
|
||||
ti.Progress = "50%"
|
||||
return ti, nil
|
||||
}
|
||||
ti.Status = model.TaskStatusSuccess
|
||||
ti.Progress = "100%"
|
||||
if len(op.Response.Videos) > 0 {
|
||||
v0 := op.Response.Videos[0]
|
||||
if v0.BytesBase64Encoded != "" {
|
||||
mime := strings.TrimSpace(v0.MimeType)
|
||||
if mime == "" {
|
||||
enc := strings.TrimSpace(v0.Encoding)
|
||||
if enc == "" {
|
||||
enc = "mp4"
|
||||
}
|
||||
if strings.Contains(enc, "/") {
|
||||
mime = enc
|
||||
} else {
|
||||
mime = "video/" + enc
|
||||
}
|
||||
}
|
||||
ti.Url = "data:" + mime + ";base64," + v0.BytesBase64Encoded
|
||||
return ti, nil
|
||||
}
|
||||
}
|
||||
if op.Response.BytesBase64Encoded != "" {
|
||||
enc := strings.TrimSpace(op.Response.Encoding)
|
||||
if enc == "" {
|
||||
enc = "mp4"
|
||||
}
|
||||
mime := enc
|
||||
if !strings.Contains(enc, "/") {
|
||||
mime = "video/" + enc
|
||||
}
|
||||
ti.Url = "data:" + mime + ";base64," + op.Response.BytesBase64Encoded
|
||||
return ti, nil
|
||||
}
|
||||
if op.Response.Video != "" { // some variants use `video` as base64
|
||||
enc := strings.TrimSpace(op.Response.Encoding)
|
||||
if enc == "" {
|
||||
enc = "mp4"
|
||||
}
|
||||
mime := enc
|
||||
if !strings.Contains(enc, "/") {
|
||||
mime = "video/" + enc
|
||||
}
|
||||
ti.Url = "data:" + mime + ";base64," + op.Response.Video
|
||||
return ti, nil
|
||||
}
|
||||
return ti, nil
|
||||
}
|
||||
|
||||
// ============================
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func encodeLocalTaskID(name string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(name))
|
||||
}
|
||||
|
||||
func decodeLocalTaskID(local string) (string, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(local)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
|
||||
|
||||
func extractRegionFromOperationName(name string) string {
|
||||
m := regionRe.FindStringSubmatch(name)
|
||||
if len(m) == 2 {
|
||||
return m[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
|
||||
|
||||
func extractModelFromOperationName(name string) string {
|
||||
m := modelRe.FindStringSubmatch(name)
|
||||
if len(m) == 2 {
|
||||
return m[1]
|
||||
}
|
||||
idx := strings.Index(name, "models/")
|
||||
if idx >= 0 {
|
||||
s := name[idx+len("models/"):]
|
||||
if p := strings.Index(s, "/operations/"); p > 0 {
|
||||
return s[:p]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var projectRe = regexp.MustCompile(`projects/([^/]+)/locations/`)
|
||||
|
||||
func extractProjectFromOperationName(name string) string {
|
||||
m := projectRe.FindStringSubmatch(name)
|
||||
if len(m) == 2 {
|
||||
return m[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -23,16 +23,6 @@ import (
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type SubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type requestPayload struct {
|
||||
Model string `json:"model"`
|
||||
Images []string `json:"images"`
|
||||
@@ -90,23 +80,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
||||
var req SubmitReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if req.Image != "" {
|
||||
info.Action = constant.TaskActionGenerate
|
||||
} else {
|
||||
info.Action = constant.TaskActionTextGenerate
|
||||
}
|
||||
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
// Use the unified validation method for TaskSubmitReq with image-based action determination
|
||||
return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
@@ -114,7 +89,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(SubmitReq)
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
@@ -211,7 +186,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
var images []string
|
||||
if req.Image != "" {
|
||||
images = []string{req.Image}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
@@ -80,16 +81,64 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
adc := &Credentials{}
|
||||
if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials file: %w", err)
|
||||
}
|
||||
func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix string) (string, error) {
|
||||
region := GetModelRegion(info.ApiVersion, info.OriginModelName)
|
||||
a.AccountCredentials = *adc
|
||||
if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
adc := &Credentials{}
|
||||
if err := common.Unmarshal([]byte(info.ApiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials file: %w", err)
|
||||
}
|
||||
a.AccountCredentials = *adc
|
||||
|
||||
if a.RequestMode == RequestModeLlama {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
), nil
|
||||
}
|
||||
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
} else {
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
|
||||
modelName,
|
||||
suffix,
|
||||
info.ApiKey,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
|
||||
region,
|
||||
modelName,
|
||||
suffix,
|
||||
info.ApiKey,
|
||||
), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
suffix := ""
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||
@@ -111,24 +160,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
suffix = "predict"
|
||||
}
|
||||
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
return a.getRequestUrl(info, info.UpstreamModelName, suffix)
|
||||
} else if a.RequestMode == RequestModeClaude {
|
||||
if info.IsStream {
|
||||
suffix = "streamRawPredict?alt=sse"
|
||||
@@ -139,41 +171,25 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
model = v
|
||||
}
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
return a.getRequestUrl(info, model, suffix)
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
), nil
|
||||
return a.getRequestUrl(info, "", "")
|
||||
}
|
||||
return "", errors.New("unsupported request mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
accessToken, err := getAccessToken(a, info)
|
||||
if err != nil {
|
||||
return err
|
||||
if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
accessToken, err := getAccessToken(a, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+accessToken)
|
||||
}
|
||||
if a.AccountCredentials.ProjectID != "" {
|
||||
req.Set("x-goog-user-project", a.AccountCredentials.ProjectID)
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+accessToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,10 @@ func GetModelRegion(other string, localModelName string) string {
|
||||
if m[localModelName] != nil {
|
||||
return m[localModelName].(string)
|
||||
} else {
|
||||
return m["default"].(string)
|
||||
if v, ok := m["default"]; ok {
|
||||
return v.(string)
|
||||
}
|
||||
return "global"
|
||||
}
|
||||
}
|
||||
return other
|
||||
|
||||
@@ -6,14 +6,15 @@ import (
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"github.com/bytedance/gopkg/cache/asynccache"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/gopkg/cache/asynccache"
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
@@ -137,3 +138,45 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s
|
||||
|
||||
return "", fmt.Errorf("failed to get access token: %v", result)
|
||||
}
|
||||
|
||||
func AcquireAccessToken(creds Credentials, proxy string) (string, error) {
|
||||
signedJWT, err := createSignedJWT(creds.ClientEmail, creds.PrivateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create signed JWT: %w", err)
|
||||
}
|
||||
return exchangeJwtForAccessTokenWithProxy(signedJWT, proxy)
|
||||
}
|
||||
|
||||
func exchangeJwtForAccessTokenWithProxy(signedJWT string, proxy string) (string, error) {
|
||||
authURL := "https://www.googleapis.com/oauth2/v4/token"
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
|
||||
data.Set("assertion", signedJWT)
|
||||
|
||||
var client *http.Client
|
||||
var err error
|
||||
if proxy != "" {
|
||||
client, err = service.NewProxyHttpClient(proxy)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
client = service.GetHttpClient()
|
||||
}
|
||||
|
||||
resp, err := client.PostForm(authURL, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if accessToken, ok := result["access_token"].(string); ok {
|
||||
return accessToken, nil
|
||||
}
|
||||
return "", fmt.Errorf("failed to get access token: %v", result)
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
httpResp = resp.(*http.Response)
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -151,7 +153,9 @@ func checkConditions(jsonStr string, conditions []ConditionOperation, logic stri
|
||||
}
|
||||
|
||||
func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
|
||||
value := gjson.Get(jsonStr, condition.Path)
|
||||
// 处理负数索引
|
||||
path := processNegativeIndex(jsonStr, condition.Path)
|
||||
value := gjson.Get(jsonStr, path)
|
||||
if !value.Exists() {
|
||||
if condition.PassMissingKey {
|
||||
return true, nil
|
||||
@@ -177,6 +181,37 @@ func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, e
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func processNegativeIndex(jsonStr string, path string) string {
|
||||
re := regexp.MustCompile(`\.(-\d+)`)
|
||||
matches := re.FindAllStringSubmatch(path, -1)
|
||||
|
||||
if len(matches) == 0 {
|
||||
return path
|
||||
}
|
||||
|
||||
result := path
|
||||
for _, match := range matches {
|
||||
negIndex := match[1]
|
||||
index, _ := strconv.Atoi(negIndex)
|
||||
|
||||
arrayPath := strings.Split(path, negIndex)[0]
|
||||
if strings.HasSuffix(arrayPath, ".") {
|
||||
arrayPath = arrayPath[:len(arrayPath)-1]
|
||||
}
|
||||
|
||||
array := gjson.Get(jsonStr, arrayPath)
|
||||
if array.IsArray() {
|
||||
length := len(array.Array())
|
||||
actualIndex := length + index
|
||||
if actualIndex >= 0 && actualIndex < length {
|
||||
result = strings.Replace(result, match[0], "."+strconv.Itoa(actualIndex), 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式
|
||||
func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) {
|
||||
switch mode {
|
||||
@@ -274,21 +309,25 @@ func applyOperations(jsonStr string, operations []ParamOperation) (string, error
|
||||
if !ok {
|
||||
continue // 条件不满足,跳过当前操作
|
||||
}
|
||||
// 处理路径中的负数索引
|
||||
opPath := processNegativeIndex(result, op.Path)
|
||||
opFrom := processNegativeIndex(result, op.From)
|
||||
opTo := processNegativeIndex(result, op.To)
|
||||
|
||||
switch op.Mode {
|
||||
case "delete":
|
||||
result, err = sjson.Delete(result, op.Path)
|
||||
result, err = sjson.Delete(result, opPath)
|
||||
case "set":
|
||||
if op.KeepOrigin && gjson.Get(result, op.Path).Exists() {
|
||||
if op.KeepOrigin && gjson.Get(result, opPath).Exists() {
|
||||
continue
|
||||
}
|
||||
result, err = sjson.Set(result, op.Path, op.Value)
|
||||
result, err = sjson.Set(result, opPath, op.Value)
|
||||
case "move":
|
||||
result, err = moveValue(result, op.From, op.To)
|
||||
result, err = moveValue(result, opFrom, opTo)
|
||||
case "prepend":
|
||||
result, err = modifyValue(result, op.Path, op.Value, op.KeepOrigin, true)
|
||||
result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, true)
|
||||
case "append":
|
||||
result, err = modifyValue(result, op.Path, op.Value, op.KeepOrigin, false)
|
||||
result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, false)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown operation: %s", op.Mode)
|
||||
}
|
||||
|
||||
@@ -481,11 +481,20 @@ type TaskSubmitReq struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func (t TaskSubmitReq) GetPrompt() string {
|
||||
return t.Prompt
|
||||
}
|
||||
|
||||
func (t TaskSubmitReq) HasImage() bool {
|
||||
return len(t.Images) > 0
|
||||
}
|
||||
|
||||
type TaskInfo struct {
|
||||
Code int `json:"code"`
|
||||
TaskID string `json:"task_id"`
|
||||
|
||||
@@ -2,14 +2,23 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type HasPrompt interface {
|
||||
GetPrompt() string
|
||||
}
|
||||
|
||||
type HasImage interface {
|
||||
HasImage() bool
|
||||
}
|
||||
|
||||
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
|
||||
@@ -32,3 +41,72 @@ func GetAPIVersion(c *gin.Context) string {
|
||||
}
|
||||
return apiVersion
|
||||
}
|
||||
|
||||
func createTaskError(err error, code string, statusCode int, localError bool) *dto.TaskError {
|
||||
return &dto.TaskError{
|
||||
Code: code,
|
||||
Message: err.Error(),
|
||||
StatusCode: statusCode,
|
||||
LocalError: localError,
|
||||
Error: err,
|
||||
}
|
||||
}
|
||||
|
||||
func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
|
||||
info.Action = action
|
||||
c.Set("task_request", requestObj)
|
||||
}
|
||||
|
||||
func validatePrompt(prompt string) *dto.TaskError {
|
||||
if strings.TrimSpace(prompt) == "" {
|
||||
return createTaskError(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest, true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
|
||||
var req TaskSubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
|
||||
}
|
||||
|
||||
if taskErr := validatePrompt(req.Prompt); taskErr != nil {
|
||||
return taskErr
|
||||
}
|
||||
|
||||
if len(req.Images) == 0 && strings.TrimSpace(req.Image) != "" {
|
||||
// 兼容单图上传
|
||||
req.Images = []string{req.Image}
|
||||
}
|
||||
|
||||
storeTaskRequest(c, info, action, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateTaskRequestWithImage(c *gin.Context, info *RelayInfo, requestObj interface{}) *dto.TaskError {
|
||||
hasPrompt, ok := requestObj.(HasPrompt)
|
||||
if !ok {
|
||||
return createTaskError(fmt.Errorf("request must have prompt"), "invalid_request", http.StatusBadRequest, true)
|
||||
}
|
||||
|
||||
if taskErr := validatePrompt(hasPrompt.GetPrompt()); taskErr != nil {
|
||||
return taskErr
|
||||
}
|
||||
|
||||
action := constant.TaskActionTextGenerate
|
||||
if hasImage, ok := requestObj.(HasImage); ok && hasImage.HasImage() {
|
||||
action = constant.TaskActionGenerate
|
||||
}
|
||||
|
||||
storeTaskRequest(c, info, action, requestObj)
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateTaskRequestWithImageBinding(c *gin.Context, info *RelayInfo) *dto.TaskError {
|
||||
var req TaskSubmitReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
return createTaskError(err, "invalid_request_body", http.StatusBadRequest, false)
|
||||
}
|
||||
|
||||
return ValidateTaskRequestWithImage(c, info, req)
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
httpResp = resp.(*http.Response)
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newApiErr := service.RelayErrorHandler(httpResp, false)
|
||||
newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newApiErr, statusCodeMappingStr)
|
||||
return newApiErr
|
||||
@@ -195,6 +195,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
imageTokens := usage.PromptTokensDetails.ImageTokens
|
||||
audioTokens := usage.PromptTokensDetails.AudioTokens
|
||||
completionTokens := usage.CompletionTokens
|
||||
cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
|
||||
modelName := relayInfo.OriginModelName
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
@@ -204,6 +206,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
modelRatio := relayInfo.PriceData.ModelRatio
|
||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := relayInfo.PriceData.ModelPrice
|
||||
cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio
|
||||
|
||||
// Convert values to decimal for precise calculation
|
||||
dPromptTokens := decimal.NewFromInt(int64(promptTokens))
|
||||
@@ -211,12 +214,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
dImageTokens := decimal.NewFromInt(int64(imageTokens))
|
||||
dAudioTokens := decimal.NewFromInt(int64(audioTokens))
|
||||
dCompletionTokens := decimal.NewFromInt(int64(completionTokens))
|
||||
dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens))
|
||||
dCompletionRatio := decimal.NewFromFloat(completionRatio)
|
||||
dCacheRatio := decimal.NewFromFloat(cacheRatio)
|
||||
dImageRatio := decimal.NewFromFloat(imageRatio)
|
||||
dModelRatio := decimal.NewFromFloat(modelRatio)
|
||||
dGroupRatio := decimal.NewFromFloat(groupRatio)
|
||||
dModelPrice := decimal.NewFromFloat(modelPrice)
|
||||
dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio)
|
||||
dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit)
|
||||
|
||||
ratio := dModelRatio.Mul(dGroupRatio)
|
||||
@@ -284,6 +289,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
baseTokens = baseTokens.Sub(dCacheTokens)
|
||||
cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio)
|
||||
}
|
||||
var dCachedCreationTokensWithRatio decimal.Decimal
|
||||
if !dCachedCreationTokens.IsZero() {
|
||||
baseTokens = baseTokens.Sub(dCachedCreationTokens)
|
||||
dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio)
|
||||
}
|
||||
|
||||
// 减去 image tokens
|
||||
var imageTokensWithRatio decimal.Decimal
|
||||
@@ -302,7 +312,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
|
||||
}
|
||||
}
|
||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio)
|
||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).
|
||||
Add(imageTokensWithRatio).
|
||||
Add(dCachedCreationTokensWithRatio)
|
||||
|
||||
completionQuota := dCompletionTokens.Mul(dCompletionRatio)
|
||||
|
||||
@@ -384,6 +396,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
other["image_ratio"] = imageRatio
|
||||
other["image_output"] = imageTokens
|
||||
}
|
||||
if cachedCreationTokens != 0 {
|
||||
other["cache_creation_tokens"] = cachedCreationTokens
|
||||
other["cache_creation_ratio"] = cachedCreationRatio
|
||||
}
|
||||
if !dWebSearchQuota.IsZero() {
|
||||
if relayInfo.ResponsesUsageInfo != nil {
|
||||
if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists {
|
||||
|
||||
@@ -58,7 +58,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -152,7 +152,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
httpResp = resp.(*http.Response)
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
@@ -249,7 +249,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
httpResp = resp.(*http.Response)
|
||||
info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
@@ -120,7 +120,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
var logContent string
|
||||
|
||||
if len(request.Size) > 0 {
|
||||
logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality)
|
||||
logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N)
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
"one-api/setting"
|
||||
"one-api/setting/system_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -131,7 +132,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
||||
midjourneyTask.FinishTime = originTask.FinishTime
|
||||
midjourneyTask.ImageUrl = ""
|
||||
if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled {
|
||||
midjourneyTask.ImageUrl = setting.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
midjourneyTask.ImageUrl = system_setting.ServerAddress + "/mj/image/" + originTask.MjId
|
||||
if originTask.Status != "SUCCESS" {
|
||||
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/constant"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/ali"
|
||||
@@ -28,6 +27,7 @@ import (
|
||||
taskjimeng "one-api/relay/channel/task/jimeng"
|
||||
"one-api/relay/channel/task/kling"
|
||||
"one-api/relay/channel/task/suno"
|
||||
taskvertex "one-api/relay/channel/task/vertex"
|
||||
taskVidu "one-api/relay/channel/task/vidu"
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/vertex"
|
||||
@@ -37,6 +37,8 @@ import (
|
||||
"one-api/relay/channel/zhipu"
|
||||
"one-api/relay/channel/zhipu_4v"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GetAdaptor(apiType int) channel.Adaptor {
|
||||
@@ -126,6 +128,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
|
||||
return &kling.TaskAdaptor{}
|
||||
case constant.ChannelTypeJimeng:
|
||||
return &taskjimeng.TaskAdaptor{}
|
||||
case constant.ChannelTypeVertexAi:
|
||||
return &taskvertex.TaskAdaptor{}
|
||||
case constant.ChannelTypeVidu:
|
||||
return &taskVidu.TaskAdaptor{}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
relayconstant "one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"one-api/setting/ratio_setting"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -33,6 +35,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
platform = GetTaskPlatform(c)
|
||||
}
|
||||
|
||||
info.InitChannelMeta(c)
|
||||
adaptor := GetTaskAdaptor(platform)
|
||||
if adaptor == nil {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
||||
@@ -197,6 +200,9 @@ func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
|
||||
if taskErr != nil {
|
||||
return taskErr
|
||||
}
|
||||
if len(respBody) == 0 {
|
||||
respBody = []byte("{\"code\":\"success\",\"data\":null}")
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||||
@@ -276,10 +282,92 @@ func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *d
|
||||
return
|
||||
}
|
||||
|
||||
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: TaskModel2Dto(originTask),
|
||||
})
|
||||
func() {
|
||||
channelModel, err2 := model.GetChannelById(originTask.ChannelId, true)
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
if channelModel.Type != constant.ChannelTypeVertexAi {
|
||||
return
|
||||
}
|
||||
baseURL := constant.ChannelBaseURLs[channelModel.Type]
|
||||
if channelModel.GetBaseURL() != "" {
|
||||
baseURL = channelModel.GetBaseURL()
|
||||
}
|
||||
adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type)))
|
||||
if adaptor == nil {
|
||||
return
|
||||
}
|
||||
resp, err2 := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{
|
||||
"task_id": originTask.TaskID,
|
||||
"action": originTask.Action,
|
||||
})
|
||||
if err2 != nil || resp == nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err2 := io.ReadAll(resp.Body)
|
||||
if err2 != nil {
|
||||
return
|
||||
}
|
||||
ti, err2 := adaptor.ParseTaskResult(body)
|
||||
if err2 == nil && ti != nil {
|
||||
if ti.Status != "" {
|
||||
originTask.Status = model.TaskStatus(ti.Status)
|
||||
}
|
||||
if ti.Progress != "" {
|
||||
originTask.Progress = ti.Progress
|
||||
}
|
||||
if ti.Url != "" {
|
||||
originTask.FailReason = ti.Url
|
||||
}
|
||||
_ = originTask.Update()
|
||||
var raw map[string]any
|
||||
_ = json.Unmarshal(body, &raw)
|
||||
format := "mp4"
|
||||
if respObj, ok := raw["response"].(map[string]any); ok {
|
||||
if vids, ok := respObj["videos"].([]any); ok && len(vids) > 0 {
|
||||
if v0, ok := vids[0].(map[string]any); ok {
|
||||
if mt, ok := v0["mimeType"].(string); ok && mt != "" {
|
||||
if strings.Contains(mt, "mp4") {
|
||||
format = "mp4"
|
||||
} else {
|
||||
format = mt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
status := "processing"
|
||||
switch originTask.Status {
|
||||
case model.TaskStatusSuccess:
|
||||
status = "succeeded"
|
||||
case model.TaskStatusFailure:
|
||||
status = "failed"
|
||||
case model.TaskStatusQueued, model.TaskStatusSubmitted:
|
||||
status = "queued"
|
||||
}
|
||||
out := map[string]any{
|
||||
"error": nil,
|
||||
"format": format,
|
||||
"metadata": nil,
|
||||
"status": status,
|
||||
"task_id": originTask.TaskID,
|
||||
"url": originTask.FailReason,
|
||||
}
|
||||
respBody, _ = json.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: out,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
if len(respBody) == 0 {
|
||||
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
||||
Code: "success",
|
||||
Data: TaskModel2Dto(originTask),
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
if resp != nil {
|
||||
httpResp = resp.(*http.Response)
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -41,7 +41,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
}
|
||||
adaptor.Init(info)
|
||||
var requestBody io.Reader
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled {
|
||||
if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeReadRequestBodyFailed, types.ErrOptionWithSkipRetry())
|
||||
@@ -82,7 +82,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
httpResp = resp.(*http.Response)
|
||||
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
newAPIError = service.RelayErrorHandler(httpResp, false)
|
||||
newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
|
||||
// reset status code 重置状态码
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
|
||||
@@ -60,6 +60,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute.DELETE("/self", controller.DeleteSelf)
|
||||
selfRoute.GET("/token", controller.GenerateAccessToken)
|
||||
selfRoute.GET("/aff", controller.GetAffCode)
|
||||
selfRoute.GET("/topup/info", controller.GetTopUpInfo)
|
||||
selfRoute.POST("/topup", middleware.CriticalRateLimit(), controller.TopUp)
|
||||
selfRoute.POST("/pay", middleware.CriticalRateLimit(), controller.RequestEpay)
|
||||
selfRoute.POST("/amount", controller.RequestAmount)
|
||||
@@ -224,6 +225,8 @@ func SetApiRouter(router *gin.Engine) {
|
||||
modelsRoute := apiRouter.Group("/models")
|
||||
modelsRoute.Use(middleware.AdminAuth())
|
||||
{
|
||||
modelsRoute.GET("/sync_upstream/preview", controller.SyncUpstreamPreview)
|
||||
modelsRoute.POST("/sync_upstream", controller.SyncUpstreamModels)
|
||||
modelsRoute.GET("/missing", controller.GetMissingModels)
|
||||
modelsRoute.GET("/", controller.GetAllModelsMeta)
|
||||
modelsRoute.GET("/search", controller.SearchModelsMeta)
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"one-api/setting"
|
||||
"one-api/setting/operation_setting"
|
||||
"one-api/setting/system_setting"
|
||||
)
|
||||
|
||||
func GetCallbackAddress() string {
|
||||
if setting.CustomCallbackAddress == "" {
|
||||
return setting.ServerAddress
|
||||
if operation_setting.CustomCallbackAddress == "" {
|
||||
return system_setting.ServerAddress
|
||||
}
|
||||
return setting.CustomCallbackAddress
|
||||
return operation_setting.CustomCallbackAddress
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -78,7 +80,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
|
||||
return claudeErr
|
||||
}
|
||||
|
||||
func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
|
||||
func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
|
||||
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
@@ -94,7 +96,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
|
||||
newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
|
||||
} else {
|
||||
if common.DebugEnabled {
|
||||
println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
|
||||
logger.LogInfo(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
|
||||
}
|
||||
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,9 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
|
||||
@@ -13,13 +13,13 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) {
|
||||
if preConsumedQuota != 0 {
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota)))
|
||||
func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||
if relayInfo.FinalPreConsumedQuota != 0 {
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(relayInfo.FinalPreConsumedQuota)))
|
||||
gopool.Go(func() {
|
||||
relayInfoCopy := *relayInfo
|
||||
|
||||
err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
|
||||
err := PostConsumeQuota(&relayInfoCopy, -relayInfoCopy.FinalPreConsumedQuota, 0, false)
|
||||
if err != nil {
|
||||
common.SysLog("error return pre-consumed quota: " + err.Error())
|
||||
}
|
||||
@@ -29,16 +29,16 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr
|
||||
|
||||
// PreConsumeQuota checks if the user has enough quota to pre-consume.
|
||||
// It returns the pre-consumed quota if successful, or an error if not.
|
||||
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) {
|
||||
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError {
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if userQuota <= 0 {
|
||||
return 0, types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
|
||||
trustQuota := common.GetTrustQuota()
|
||||
@@ -65,14 +65,14 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
||||
if preConsumedQuota > 0 {
|
||||
err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
|
||||
return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
|
||||
}
|
||||
relayInfo.FinalPreConsumedQuota = preConsumedQuota
|
||||
return preConsumedQuota, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"one-api/logger"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/setting/system_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -534,9 +534,28 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
|
||||
}
|
||||
if quotaTooLow {
|
||||
prompt := "您的额度即将用尽"
|
||||
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
|
||||
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
|
||||
err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
|
||||
topUpLink := fmt.Sprintf("%s/topup", system_setting.ServerAddress)
|
||||
|
||||
// 根据通知方式生成不同的内容格式
|
||||
var content string
|
||||
var values []interface{}
|
||||
|
||||
notifyType := userSetting.NotifyType
|
||||
if notifyType == "" {
|
||||
notifyType = dto.NotifyTypeEmail
|
||||
}
|
||||
|
||||
if notifyType == dto.NotifyTypeBark {
|
||||
// Bark推送使用简短文本,不支持HTML
|
||||
content = "{{value}},剩余额度:{{value}},请及时充值"
|
||||
values = []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota)}
|
||||
} else {
|
||||
// 默认内容格式,适用于Email和Webhook
|
||||
content = "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
|
||||
values = []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}
|
||||
}
|
||||
|
||||
err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, values))
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
|
||||
}
|
||||
|
||||
@@ -5,6 +5,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"log"
|
||||
"math"
|
||||
"one-api/common"
|
||||
@@ -357,33 +360,6 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco
|
||||
return tkm, nil
|
||||
}
|
||||
|
||||
//func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
|
||||
// tkm := 0
|
||||
// msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
|
||||
// if err != nil {
|
||||
// return 0, err
|
||||
// }
|
||||
// tkm += msgTokens
|
||||
// if request.Tools != nil {
|
||||
// openaiTools := request.Tools
|
||||
// countStr := ""
|
||||
// for _, tool := range openaiTools {
|
||||
// countStr = tool.Function.Name
|
||||
// if tool.Function.Description != "" {
|
||||
// countStr += tool.Function.Description
|
||||
// }
|
||||
// if tool.Function.Parameters != nil {
|
||||
// countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||
// }
|
||||
// }
|
||||
// toolTokens := CountTokenInput(countStr, request.Model)
|
||||
// tkm += 8
|
||||
// tkm += toolTokens
|
||||
// }
|
||||
//
|
||||
// return tkm, nil
|
||||
//}
|
||||
|
||||
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
|
||||
tkm := 0
|
||||
|
||||
@@ -543,56 +519,6 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
return textToken, audioToken, nil
|
||||
}
|
||||
|
||||
//func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
|
||||
// //recover when panic
|
||||
// tokenEncoder := getTokenEncoder(model)
|
||||
// // Reference:
|
||||
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// // https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
// //
|
||||
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
// var tokensPerMessage int
|
||||
// var tokensPerName int
|
||||
//
|
||||
// tokensPerMessage = 3
|
||||
// tokensPerName = 1
|
||||
//
|
||||
// tokenNum := 0
|
||||
// for _, message := range messages {
|
||||
// tokenNum += tokensPerMessage
|
||||
// tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
// if message.Content != nil {
|
||||
// if message.Name != nil {
|
||||
// tokenNum += tokensPerName
|
||||
// tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||
// }
|
||||
// arrayContent := message.ParseContent()
|
||||
// for _, m := range arrayContent {
|
||||
// if m.Type == dto.ContentTypeImageURL {
|
||||
// imageUrl := m.GetImageMedia()
|
||||
// imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
|
||||
// if err != nil {
|
||||
// return 0, err
|
||||
// }
|
||||
// tokenNum += imageTokenNum
|
||||
// log.Printf("image token num: %d", imageTokenNum)
|
||||
// } else if m.Type == dto.ContentTypeInputAudio {
|
||||
// // TODO: 音频token数量计算
|
||||
// tokenNum += 100
|
||||
// } else if m.Type == dto.ContentTypeFile {
|
||||
// tokenNum += 5000
|
||||
// } else if m.Type == dto.ContentTypeVideoUrl {
|
||||
// tokenNum += 5000
|
||||
// } else {
|
||||
// tokenNum += getTokenNum(tokenEncoder, m.Text)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||
// return tokenNum, nil
|
||||
//}
|
||||
|
||||
func CountTokenInput(input any, model string) int {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
|
||||
@@ -2,9 +2,12 @@ package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -51,6 +54,13 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data
|
||||
// 获取 webhook secret
|
||||
webhookSecret := userSetting.WebhookSecret
|
||||
return SendWebhookNotify(webhookURLStr, webhookSecret, data)
|
||||
case dto.NotifyTypeBark:
|
||||
barkURL := userSetting.BarkUrl
|
||||
if barkURL == "" {
|
||||
common.SysLog(fmt.Sprintf("user %d has no bark url, skip sending bark", userId))
|
||||
return nil
|
||||
}
|
||||
return sendBarkNotify(barkURL, data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -64,3 +74,67 @@ func sendEmailNotify(userEmail string, data dto.Notify) error {
|
||||
}
|
||||
return common.SendEmail(data.Title, userEmail, content)
|
||||
}
|
||||
|
||||
func sendBarkNotify(barkURL string, data dto.Notify) error {
|
||||
// 处理占位符
|
||||
content := data.Content
|
||||
for _, value := range data.Values {
|
||||
content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1)
|
||||
}
|
||||
|
||||
// 替换模板变量
|
||||
finalURL := strings.ReplaceAll(barkURL, "{{title}}", url.QueryEscape(data.Title))
|
||||
finalURL = strings.ReplaceAll(finalURL, "{{content}}", url.QueryEscape(content))
|
||||
|
||||
// 发送GET请求到Bark
|
||||
var req *http.Request
|
||||
var resp *http.Response
|
||||
var err error
|
||||
|
||||
if setting.EnableWorker() {
|
||||
// 使用worker发送请求
|
||||
workerReq := &WorkerRequest{
|
||||
URL: finalURL,
|
||||
Key: setting.WorkerValidKey,
|
||||
Method: http.MethodGet,
|
||||
Headers: map[string]string{
|
||||
"User-Agent": "OneAPI-Bark-Notify/1.0",
|
||||
},
|
||||
}
|
||||
|
||||
resp, err = DoWorkerRequest(workerReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send bark request through worker: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode)
|
||||
}
|
||||
} else {
|
||||
// 直接发送请求
|
||||
req, err = http.NewRequest(http.MethodGet, finalURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create bark request: %v", err)
|
||||
}
|
||||
|
||||
// 设置User-Agent
|
||||
req.Header.Set("User-Agent", "OneAPI-Bark-Notify/1.0")
|
||||
|
||||
// 发送请求
|
||||
client := GetHttpClient()
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send bark request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,37 +3,37 @@ package console_setting
|
||||
import "one-api/setting/config"
|
||||
|
||||
type ConsoleSetting struct {
|
||||
ApiInfo string `json:"api_info"` // 控制台 API 信息 (JSON 数组字符串)
|
||||
UptimeKumaGroups string `json:"uptime_kuma_groups"` // Uptime Kuma 分组配置 (JSON 数组字符串)
|
||||
Announcements string `json:"announcements"` // 系统公告 (JSON 数组字符串)
|
||||
FAQ string `json:"faq"` // 常见问题 (JSON 数组字符串)
|
||||
ApiInfoEnabled bool `json:"api_info_enabled"` // 是否启用 API 信息面板
|
||||
UptimeKumaEnabled bool `json:"uptime_kuma_enabled"` // 是否启用 Uptime Kuma 面板
|
||||
AnnouncementsEnabled bool `json:"announcements_enabled"` // 是否启用系统公告面板
|
||||
FAQEnabled bool `json:"faq_enabled"` // 是否启用常见问答面板
|
||||
ApiInfo string `json:"api_info"` // 控制台 API 信息 (JSON 数组字符串)
|
||||
UptimeKumaGroups string `json:"uptime_kuma_groups"` // Uptime Kuma 分组配置 (JSON 数组字符串)
|
||||
Announcements string `json:"announcements"` // 系统公告 (JSON 数组字符串)
|
||||
FAQ string `json:"faq"` // 常见问题 (JSON 数组字符串)
|
||||
ApiInfoEnabled bool `json:"api_info_enabled"` // 是否启用 API 信息面板
|
||||
UptimeKumaEnabled bool `json:"uptime_kuma_enabled"` // 是否启用 Uptime Kuma 面板
|
||||
AnnouncementsEnabled bool `json:"announcements_enabled"` // 是否启用系统公告面板
|
||||
FAQEnabled bool `json:"faq_enabled"` // 是否启用常见问答面板
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var defaultConsoleSetting = ConsoleSetting{
|
||||
ApiInfo: "",
|
||||
UptimeKumaGroups: "",
|
||||
Announcements: "",
|
||||
FAQ: "",
|
||||
ApiInfoEnabled: true,
|
||||
UptimeKumaEnabled: true,
|
||||
AnnouncementsEnabled: true,
|
||||
FAQEnabled: true,
|
||||
ApiInfo: "",
|
||||
UptimeKumaGroups: "",
|
||||
Announcements: "",
|
||||
FAQ: "",
|
||||
ApiInfoEnabled: true,
|
||||
UptimeKumaEnabled: true,
|
||||
AnnouncementsEnabled: true,
|
||||
FAQEnabled: true,
|
||||
}
|
||||
|
||||
// 全局实例
|
||||
var consoleSetting = defaultConsoleSetting
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器,键名为 console_setting
|
||||
config.GlobalConfig.Register("console_setting", &consoleSetting)
|
||||
// 注册到全局配置管理器,键名为 console_setting
|
||||
config.GlobalConfig.Register("console_setting", &consoleSetting)
|
||||
}
|
||||
|
||||
// GetConsoleSetting 获取 ConsoleSetting 配置实例
|
||||
func GetConsoleSetting() *ConsoleSetting {
|
||||
return &consoleSetting
|
||||
}
|
||||
return &consoleSetting
|
||||
}
|
||||
|
||||
@@ -1,304 +1,304 @@
|
||||
package console_setting
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
"sort"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
urlRegex = regexp.MustCompile(`^https?://(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?|(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))(?:\:[0-9]{1,5})?(?:/.*)?$`)
|
||||
dangerousChars = []string{"<script", "<iframe", "javascript:", "onload=", "onerror=", "onclick="}
|
||||
validColors = map[string]bool{
|
||||
"blue": true, "green": true, "cyan": true, "purple": true, "pink": true,
|
||||
"red": true, "orange": true, "amber": true, "yellow": true, "lime": true,
|
||||
"light-green": true, "teal": true, "light-blue": true, "indigo": true,
|
||||
"violet": true, "grey": true,
|
||||
}
|
||||
slugRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
urlRegex = regexp.MustCompile(`^https?://(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?|(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))(?:\:[0-9]{1,5})?(?:/.*)?$`)
|
||||
dangerousChars = []string{"<script", "<iframe", "javascript:", "onload=", "onerror=", "onclick="}
|
||||
validColors = map[string]bool{
|
||||
"blue": true, "green": true, "cyan": true, "purple": true, "pink": true,
|
||||
"red": true, "orange": true, "amber": true, "yellow": true, "lime": true,
|
||||
"light-green": true, "teal": true, "light-blue": true, "indigo": true,
|
||||
"violet": true, "grey": true,
|
||||
}
|
||||
slugRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
|
||||
)
|
||||
|
||||
func parseJSONArray(jsonStr string, typeName string) ([]map[string]interface{}, error) {
|
||||
var list []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &list); err != nil {
|
||||
return nil, fmt.Errorf("%s格式错误:%s", typeName, err.Error())
|
||||
}
|
||||
return list, nil
|
||||
var list []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &list); err != nil {
|
||||
return nil, fmt.Errorf("%s格式错误:%s", typeName, err.Error())
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func validateURL(urlStr string, index int, itemType string) error {
|
||||
if !urlRegex.MatchString(urlStr) {
|
||||
return fmt.Errorf("第%d个%s的URL格式不正确", index, itemType)
|
||||
}
|
||||
if _, err := url.Parse(urlStr); err != nil {
|
||||
return fmt.Errorf("第%d个%s的URL无法解析:%s", index, itemType, err.Error())
|
||||
}
|
||||
return nil
|
||||
if !urlRegex.MatchString(urlStr) {
|
||||
return fmt.Errorf("第%d个%s的URL格式不正确", index, itemType)
|
||||
}
|
||||
if _, err := url.Parse(urlStr); err != nil {
|
||||
return fmt.Errorf("第%d个%s的URL无法解析:%s", index, itemType, err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkDangerousContent(content string, index int, itemType string) error {
|
||||
lower := strings.ToLower(content)
|
||||
for _, d := range dangerousChars {
|
||||
if strings.Contains(lower, d) {
|
||||
return fmt.Errorf("第%d个%s包含不允许的内容", index, itemType)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
lower := strings.ToLower(content)
|
||||
for _, d := range dangerousChars {
|
||||
if strings.Contains(lower, d) {
|
||||
return fmt.Errorf("第%d个%s包含不允许的内容", index, itemType)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getJSONList(jsonStr string) []map[string]interface{} {
|
||||
if jsonStr == "" {
|
||||
return []map[string]interface{}{}
|
||||
}
|
||||
var list []map[string]interface{}
|
||||
json.Unmarshal([]byte(jsonStr), &list)
|
||||
return list
|
||||
if jsonStr == "" {
|
||||
return []map[string]interface{}{}
|
||||
}
|
||||
var list []map[string]interface{}
|
||||
json.Unmarshal([]byte(jsonStr), &list)
|
||||
return list
|
||||
}
|
||||
|
||||
func ValidateConsoleSettings(settingsStr string, settingType string) error {
|
||||
if settingsStr == "" {
|
||||
return nil
|
||||
}
|
||||
if settingsStr == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch settingType {
|
||||
case "ApiInfo":
|
||||
return validateApiInfo(settingsStr)
|
||||
case "Announcements":
|
||||
return validateAnnouncements(settingsStr)
|
||||
case "FAQ":
|
||||
return validateFAQ(settingsStr)
|
||||
case "UptimeKumaGroups":
|
||||
return validateUptimeKumaGroups(settingsStr)
|
||||
default:
|
||||
return fmt.Errorf("未知的设置类型:%s", settingType)
|
||||
}
|
||||
switch settingType {
|
||||
case "ApiInfo":
|
||||
return validateApiInfo(settingsStr)
|
||||
case "Announcements":
|
||||
return validateAnnouncements(settingsStr)
|
||||
case "FAQ":
|
||||
return validateFAQ(settingsStr)
|
||||
case "UptimeKumaGroups":
|
||||
return validateUptimeKumaGroups(settingsStr)
|
||||
default:
|
||||
return fmt.Errorf("未知的设置类型:%s", settingType)
|
||||
}
|
||||
}
|
||||
|
||||
func validateApiInfo(apiInfoStr string) error {
|
||||
apiInfoList, err := parseJSONArray(apiInfoStr, "API信息")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
apiInfoList, err := parseJSONArray(apiInfoStr, "API信息")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(apiInfoList) > 50 {
|
||||
return fmt.Errorf("API信息数量不能超过50个")
|
||||
}
|
||||
if len(apiInfoList) > 50 {
|
||||
return fmt.Errorf("API信息数量不能超过50个")
|
||||
}
|
||||
|
||||
for i, apiInfo := range apiInfoList {
|
||||
urlStr, ok := apiInfo["url"].(string)
|
||||
if !ok || urlStr == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少URL字段", i+1)
|
||||
}
|
||||
route, ok := apiInfo["route"].(string)
|
||||
if !ok || route == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少线路描述字段", i+1)
|
||||
}
|
||||
description, ok := apiInfo["description"].(string)
|
||||
if !ok || description == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少说明字段", i+1)
|
||||
}
|
||||
color, ok := apiInfo["color"].(string)
|
||||
if !ok || color == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少颜色字段", i+1)
|
||||
}
|
||||
for i, apiInfo := range apiInfoList {
|
||||
urlStr, ok := apiInfo["url"].(string)
|
||||
if !ok || urlStr == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少URL字段", i+1)
|
||||
}
|
||||
route, ok := apiInfo["route"].(string)
|
||||
if !ok || route == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少线路描述字段", i+1)
|
||||
}
|
||||
description, ok := apiInfo["description"].(string)
|
||||
if !ok || description == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少说明字段", i+1)
|
||||
}
|
||||
color, ok := apiInfo["color"].(string)
|
||||
if !ok || color == "" {
|
||||
return fmt.Errorf("第%d个API信息缺少颜色字段", i+1)
|
||||
}
|
||||
|
||||
if err := validateURL(urlStr, i+1, "API信息"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateURL(urlStr, i+1, "API信息"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(urlStr) > 500 {
|
||||
return fmt.Errorf("第%d个API信息的URL长度不能超过500字符", i+1)
|
||||
}
|
||||
if len(route) > 100 {
|
||||
return fmt.Errorf("第%d个API信息的线路描述长度不能超过100字符", i+1)
|
||||
}
|
||||
if len(description) > 200 {
|
||||
return fmt.Errorf("第%d个API信息的说明长度不能超过200字符", i+1)
|
||||
}
|
||||
if len(urlStr) > 500 {
|
||||
return fmt.Errorf("第%d个API信息的URL长度不能超过500字符", i+1)
|
||||
}
|
||||
if len(route) > 100 {
|
||||
return fmt.Errorf("第%d个API信息的线路描述长度不能超过100字符", i+1)
|
||||
}
|
||||
if len(description) > 200 {
|
||||
return fmt.Errorf("第%d个API信息的说明长度不能超过200字符", i+1)
|
||||
}
|
||||
|
||||
if !validColors[color] {
|
||||
return fmt.Errorf("第%d个API信息的颜色值不合法", i+1)
|
||||
}
|
||||
if !validColors[color] {
|
||||
return fmt.Errorf("第%d个API信息的颜色值不合法", i+1)
|
||||
}
|
||||
|
||||
if err := checkDangerousContent(description, i+1, "API信息"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkDangerousContent(route, i+1, "API信息"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
if err := checkDangerousContent(description, i+1, "API信息"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkDangerousContent(route, i+1, "API信息"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetApiInfo() []map[string]interface{} {
|
||||
return getJSONList(GetConsoleSetting().ApiInfo)
|
||||
return getJSONList(GetConsoleSetting().ApiInfo)
|
||||
}
|
||||
|
||||
func validateAnnouncements(announcementsStr string) error {
|
||||
list, err := parseJSONArray(announcementsStr, "系统公告")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(list) > 100 {
|
||||
return fmt.Errorf("系统公告数量不能超过100个")
|
||||
}
|
||||
validTypes := map[string]bool{
|
||||
"default": true, "ongoing": true, "success": true, "warning": true, "error": true,
|
||||
}
|
||||
for i, ann := range list {
|
||||
content, ok := ann["content"].(string)
|
||||
if !ok || content == "" {
|
||||
return fmt.Errorf("第%d个公告缺少内容字段", i+1)
|
||||
}
|
||||
publishDateAny, exists := ann["publishDate"]
|
||||
if !exists {
|
||||
return fmt.Errorf("第%d个公告缺少发布日期字段", i+1)
|
||||
}
|
||||
publishDateStr, ok := publishDateAny.(string)
|
||||
if !ok || publishDateStr == "" {
|
||||
return fmt.Errorf("第%d个公告的发布日期不能为空", i+1)
|
||||
}
|
||||
if _, err := time.Parse(time.RFC3339, publishDateStr); err != nil {
|
||||
return fmt.Errorf("第%d个公告的发布日期格式错误", i+1)
|
||||
}
|
||||
if t, exists := ann["type"]; exists {
|
||||
if typeStr, ok := t.(string); ok {
|
||||
if !validTypes[typeStr] {
|
||||
return fmt.Errorf("第%d个公告的类型值不合法", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(content) > 500 {
|
||||
return fmt.Errorf("第%d个公告的内容长度不能超过500字符", i+1)
|
||||
}
|
||||
if extra, exists := ann["extra"]; exists {
|
||||
if extraStr, ok := extra.(string); ok && len(extraStr) > 200 {
|
||||
return fmt.Errorf("第%d个公告的说明长度不能超过200字符", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
list, err := parseJSONArray(announcementsStr, "系统公告")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(list) > 100 {
|
||||
return fmt.Errorf("系统公告数量不能超过100个")
|
||||
}
|
||||
validTypes := map[string]bool{
|
||||
"default": true, "ongoing": true, "success": true, "warning": true, "error": true,
|
||||
}
|
||||
for i, ann := range list {
|
||||
content, ok := ann["content"].(string)
|
||||
if !ok || content == "" {
|
||||
return fmt.Errorf("第%d个公告缺少内容字段", i+1)
|
||||
}
|
||||
publishDateAny, exists := ann["publishDate"]
|
||||
if !exists {
|
||||
return fmt.Errorf("第%d个公告缺少发布日期字段", i+1)
|
||||
}
|
||||
publishDateStr, ok := publishDateAny.(string)
|
||||
if !ok || publishDateStr == "" {
|
||||
return fmt.Errorf("第%d个公告的发布日期不能为空", i+1)
|
||||
}
|
||||
if _, err := time.Parse(time.RFC3339, publishDateStr); err != nil {
|
||||
return fmt.Errorf("第%d个公告的发布日期格式错误", i+1)
|
||||
}
|
||||
if t, exists := ann["type"]; exists {
|
||||
if typeStr, ok := t.(string); ok {
|
||||
if !validTypes[typeStr] {
|
||||
return fmt.Errorf("第%d个公告的类型值不合法", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(content) > 500 {
|
||||
return fmt.Errorf("第%d个公告的内容长度不能超过500字符", i+1)
|
||||
}
|
||||
if extra, exists := ann["extra"]; exists {
|
||||
if extraStr, ok := extra.(string); ok && len(extraStr) > 200 {
|
||||
return fmt.Errorf("第%d个公告的说明长度不能超过200字符", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateFAQ(faqStr string) error {
|
||||
list, err := parseJSONArray(faqStr, "FAQ信息")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(list) > 100 {
|
||||
return fmt.Errorf("FAQ数量不能超过100个")
|
||||
}
|
||||
for i, faq := range list {
|
||||
question, ok := faq["question"].(string)
|
||||
if !ok || question == "" {
|
||||
return fmt.Errorf("第%d个FAQ缺少问题字段", i+1)
|
||||
}
|
||||
answer, ok := faq["answer"].(string)
|
||||
if !ok || answer == "" {
|
||||
return fmt.Errorf("第%d个FAQ缺少答案字段", i+1)
|
||||
}
|
||||
if len(question) > 200 {
|
||||
return fmt.Errorf("第%d个FAQ的问题长度不能超过200字符", i+1)
|
||||
}
|
||||
if len(answer) > 1000 {
|
||||
return fmt.Errorf("第%d个FAQ的答案长度不能超过1000字符", i+1)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
list, err := parseJSONArray(faqStr, "FAQ信息")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(list) > 100 {
|
||||
return fmt.Errorf("FAQ数量不能超过100个")
|
||||
}
|
||||
for i, faq := range list {
|
||||
question, ok := faq["question"].(string)
|
||||
if !ok || question == "" {
|
||||
return fmt.Errorf("第%d个FAQ缺少问题字段", i+1)
|
||||
}
|
||||
answer, ok := faq["answer"].(string)
|
||||
if !ok || answer == "" {
|
||||
return fmt.Errorf("第%d个FAQ缺少答案字段", i+1)
|
||||
}
|
||||
if len(question) > 200 {
|
||||
return fmt.Errorf("第%d个FAQ的问题长度不能超过200字符", i+1)
|
||||
}
|
||||
if len(answer) > 1000 {
|
||||
return fmt.Errorf("第%d个FAQ的答案长度不能超过1000字符", i+1)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getPublishTime(item map[string]interface{}) time.Time {
|
||||
if v, ok := item["publishDate"]; ok {
|
||||
if s, ok2 := v.(string); ok2 {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
}
|
||||
return time.Time{}
|
||||
if v, ok := item["publishDate"]; ok {
|
||||
if s, ok2 := v.(string); ok2 {
|
||||
if t, err := time.Parse(time.RFC3339, s); err == nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func GetAnnouncements() []map[string]interface{} {
|
||||
list := getJSONList(GetConsoleSetting().Announcements)
|
||||
sort.SliceStable(list, func(i, j int) bool {
|
||||
return getPublishTime(list[i]).After(getPublishTime(list[j]))
|
||||
})
|
||||
return list
|
||||
list := getJSONList(GetConsoleSetting().Announcements)
|
||||
sort.SliceStable(list, func(i, j int) bool {
|
||||
return getPublishTime(list[i]).After(getPublishTime(list[j]))
|
||||
})
|
||||
return list
|
||||
}
|
||||
|
||||
func GetFAQ() []map[string]interface{} {
|
||||
return getJSONList(GetConsoleSetting().FAQ)
|
||||
return getJSONList(GetConsoleSetting().FAQ)
|
||||
}
|
||||
|
||||
func validateUptimeKumaGroups(groupsStr string) error {
|
||||
groups, err := parseJSONArray(groupsStr, "Uptime Kuma分组配置")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
groups, err := parseJSONArray(groupsStr, "Uptime Kuma分组配置")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(groups) > 20 {
|
||||
return fmt.Errorf("Uptime Kuma分组数量不能超过20个")
|
||||
}
|
||||
if len(groups) > 20 {
|
||||
return fmt.Errorf("Uptime Kuma分组数量不能超过20个")
|
||||
}
|
||||
|
||||
nameSet := make(map[string]bool)
|
||||
nameSet := make(map[string]bool)
|
||||
|
||||
for i, group := range groups {
|
||||
categoryName, ok := group["categoryName"].(string)
|
||||
if !ok || categoryName == "" {
|
||||
return fmt.Errorf("第%d个分组缺少分类名称字段", i+1)
|
||||
}
|
||||
if nameSet[categoryName] {
|
||||
return fmt.Errorf("第%d个分组的分类名称与其他分组重复", i+1)
|
||||
}
|
||||
nameSet[categoryName] = true
|
||||
urlStr, ok := group["url"].(string)
|
||||
if !ok || urlStr == "" {
|
||||
return fmt.Errorf("第%d个分组缺少URL字段", i+1)
|
||||
}
|
||||
slug, ok := group["slug"].(string)
|
||||
if !ok || slug == "" {
|
||||
return fmt.Errorf("第%d个分组缺少Slug字段", i+1)
|
||||
}
|
||||
description, ok := group["description"].(string)
|
||||
if !ok {
|
||||
description = ""
|
||||
}
|
||||
for i, group := range groups {
|
||||
categoryName, ok := group["categoryName"].(string)
|
||||
if !ok || categoryName == "" {
|
||||
return fmt.Errorf("第%d个分组缺少分类名称字段", i+1)
|
||||
}
|
||||
if nameSet[categoryName] {
|
||||
return fmt.Errorf("第%d个分组的分类名称与其他分组重复", i+1)
|
||||
}
|
||||
nameSet[categoryName] = true
|
||||
urlStr, ok := group["url"].(string)
|
||||
if !ok || urlStr == "" {
|
||||
return fmt.Errorf("第%d个分组缺少URL字段", i+1)
|
||||
}
|
||||
slug, ok := group["slug"].(string)
|
||||
if !ok || slug == "" {
|
||||
return fmt.Errorf("第%d个分组缺少Slug字段", i+1)
|
||||
}
|
||||
description, ok := group["description"].(string)
|
||||
if !ok {
|
||||
description = ""
|
||||
}
|
||||
|
||||
if err := validateURL(urlStr, i+1, "分组"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateURL(urlStr, i+1, "分组"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(categoryName) > 50 {
|
||||
return fmt.Errorf("第%d个分组的分类名称长度不能超过50字符", i+1)
|
||||
}
|
||||
if len(urlStr) > 500 {
|
||||
return fmt.Errorf("第%d个分组的URL长度不能超过500字符", i+1)
|
||||
}
|
||||
if len(slug) > 100 {
|
||||
return fmt.Errorf("第%d个分组的Slug长度不能超过100字符", i+1)
|
||||
}
|
||||
if len(description) > 200 {
|
||||
return fmt.Errorf("第%d个分组的描述长度不能超过200字符", i+1)
|
||||
}
|
||||
if len(categoryName) > 50 {
|
||||
return fmt.Errorf("第%d个分组的分类名称长度不能超过50字符", i+1)
|
||||
}
|
||||
if len(urlStr) > 500 {
|
||||
return fmt.Errorf("第%d个分组的URL长度不能超过500字符", i+1)
|
||||
}
|
||||
if len(slug) > 100 {
|
||||
return fmt.Errorf("第%d个分组的Slug长度不能超过100字符", i+1)
|
||||
}
|
||||
if len(description) > 200 {
|
||||
return fmt.Errorf("第%d个分组的描述长度不能超过200字符", i+1)
|
||||
}
|
||||
|
||||
if !slugRegex.MatchString(slug) {
|
||||
return fmt.Errorf("第%d个分组的Slug只能包含字母、数字、下划线和连字符", i+1)
|
||||
}
|
||||
if !slugRegex.MatchString(slug) {
|
||||
return fmt.Errorf("第%d个分组的Slug只能包含字母、数字、下划线和连字符", i+1)
|
||||
}
|
||||
|
||||
if err := checkDangerousContent(description, i+1, "分组"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkDangerousContent(categoryName, i+1, "分组"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
if err := checkDangerousContent(description, i+1, "分组"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := checkDangerousContent(categoryName, i+1, "分组"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetUptimeKumaGroups() []map[string]interface{} {
|
||||
return getJSONList(GetConsoleSetting().UptimeKumaGroups)
|
||||
}
|
||||
return getJSONList(GetConsoleSetting().UptimeKumaGroups)
|
||||
}
|
||||
|
||||
34
setting/operation_setting/monitor_setting.go
Normal file
34
setting/operation_setting/monitor_setting.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package operation_setting
|
||||
|
||||
import (
|
||||
"one-api/setting/config"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type MonitorSetting struct {
|
||||
AutoTestChannelEnabled bool `json:"auto_test_channel_enabled"`
|
||||
AutoTestChannelMinutes int `json:"auto_test_channel_minutes"`
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var monitorSetting = MonitorSetting{
|
||||
AutoTestChannelEnabled: false,
|
||||
AutoTestChannelMinutes: 10,
|
||||
}
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器
|
||||
config.GlobalConfig.Register("monitor_setting", &monitorSetting)
|
||||
}
|
||||
|
||||
func GetMonitorSetting() *MonitorSetting {
|
||||
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
|
||||
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
|
||||
if err == nil && frequency > 0 {
|
||||
monitorSetting.AutoTestChannelEnabled = true
|
||||
monitorSetting.AutoTestChannelMinutes = frequency
|
||||
}
|
||||
}
|
||||
return &monitorSetting
|
||||
}
|
||||
23
setting/operation_setting/payment_setting.go
Normal file
23
setting/operation_setting/payment_setting.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package operation_setting
|
||||
|
||||
import "one-api/setting/config"
|
||||
|
||||
type PaymentSetting struct {
|
||||
AmountOptions []int `json:"amount_options"`
|
||||
AmountDiscount map[int]float64 `json:"amount_discount"` // 充值金额对应的折扣,例如 100 元 0.9 表示 100 元充值享受 9 折优惠
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
var paymentSetting = PaymentSetting{
|
||||
AmountOptions: []int{10, 20, 50, 100, 200, 500},
|
||||
AmountDiscount: map[int]float64{},
|
||||
}
|
||||
|
||||
func init() {
|
||||
// 注册到全局配置管理器
|
||||
config.GlobalConfig.Register("payment_setting", &paymentSetting)
|
||||
}
|
||||
|
||||
func GetPaymentSetting() *PaymentSetting {
|
||||
return &paymentSetting
|
||||
}
|
||||
@@ -1,6 +1,13 @@
|
||||
package setting
|
||||
/**
|
||||
此文件为旧版支付设置文件,如需增加新的参数、变量等,请在 payment_setting.go 中添加
|
||||
This file is the old version of the payment settings file. If you need to add new parameters, variables, etc., please add them in payment_setting.go
|
||||
*/
|
||||
|
||||
import "encoding/json"
|
||||
package operation_setting
|
||||
|
||||
import (
|
||||
"one-api/common"
|
||||
)
|
||||
|
||||
var PayAddress = ""
|
||||
var CustomCallbackAddress = ""
|
||||
@@ -21,15 +28,21 @@ var PayMethods = []map[string]string{
|
||||
"color": "rgba(var(--semi-green-5), 1)",
|
||||
"type": "wxpay",
|
||||
},
|
||||
{
|
||||
"name": "自定义1",
|
||||
"color": "black",
|
||||
"type": "custom1",
|
||||
"min_topup": "50",
|
||||
},
|
||||
}
|
||||
|
||||
func UpdatePayMethodsByJsonString(jsonString string) error {
|
||||
PayMethods = make([]map[string]string, 0)
|
||||
return json.Unmarshal([]byte(jsonString), &PayMethods)
|
||||
return common.Unmarshal([]byte(jsonString), &PayMethods)
|
||||
}
|
||||
|
||||
func PayMethods2JsonString() string {
|
||||
jsonBytes, err := json.Marshal(PayMethods)
|
||||
jsonBytes, err := common.Marshal(PayMethods)
|
||||
if err != nil {
|
||||
return "[]"
|
||||
}
|
||||
@@ -5,13 +5,13 @@ import "sync/atomic"
|
||||
var exposeRatioEnabled atomic.Bool
|
||||
|
||||
func init() {
|
||||
exposeRatioEnabled.Store(false)
|
||||
exposeRatioEnabled.Store(false)
|
||||
}
|
||||
|
||||
func SetExposeRatioEnabled(enabled bool) {
|
||||
exposeRatioEnabled.Store(enabled)
|
||||
exposeRatioEnabled.Store(enabled)
|
||||
}
|
||||
|
||||
func IsExposeRatioEnabled() bool {
|
||||
return exposeRatioEnabled.Load()
|
||||
}
|
||||
return exposeRatioEnabled.Load()
|
||||
}
|
||||
|
||||
@@ -1,55 +1,55 @@
|
||||
package ratio_setting
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const exposedDataTTL = 30 * time.Second
|
||||
|
||||
type exposedCache struct {
|
||||
data gin.H
|
||||
expiresAt time.Time
|
||||
data gin.H
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
exposedData atomic.Value
|
||||
rebuildMu sync.Mutex
|
||||
exposedData atomic.Value
|
||||
rebuildMu sync.Mutex
|
||||
)
|
||||
|
||||
func InvalidateExposedDataCache() {
|
||||
exposedData.Store((*exposedCache)(nil))
|
||||
exposedData.Store((*exposedCache)(nil))
|
||||
}
|
||||
|
||||
func cloneGinH(src gin.H) gin.H {
|
||||
dst := make(gin.H, len(src))
|
||||
for k, v := range src {
|
||||
dst[k] = v
|
||||
}
|
||||
return dst
|
||||
dst := make(gin.H, len(src))
|
||||
for k, v := range src {
|
||||
dst[k] = v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func GetExposedData() gin.H {
|
||||
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
|
||||
return cloneGinH(c.data)
|
||||
}
|
||||
rebuildMu.Lock()
|
||||
defer rebuildMu.Unlock()
|
||||
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
|
||||
return cloneGinH(c.data)
|
||||
}
|
||||
newData := gin.H{
|
||||
"model_ratio": GetModelRatioCopy(),
|
||||
"completion_ratio": GetCompletionRatioCopy(),
|
||||
"cache_ratio": GetCacheRatioCopy(),
|
||||
"model_price": GetModelPriceCopy(),
|
||||
}
|
||||
exposedData.Store(&exposedCache{
|
||||
data: newData,
|
||||
expiresAt: time.Now().Add(exposedDataTTL),
|
||||
})
|
||||
return cloneGinH(newData)
|
||||
}
|
||||
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
|
||||
return cloneGinH(c.data)
|
||||
}
|
||||
rebuildMu.Lock()
|
||||
defer rebuildMu.Unlock()
|
||||
if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) {
|
||||
return cloneGinH(c.data)
|
||||
}
|
||||
newData := gin.H{
|
||||
"model_ratio": GetModelRatioCopy(),
|
||||
"completion_ratio": GetCompletionRatioCopy(),
|
||||
"cache_ratio": GetCacheRatioCopy(),
|
||||
"model_price": GetModelPriceCopy(),
|
||||
}
|
||||
exposedData.Store(&exposedCache{
|
||||
data: newData,
|
||||
expiresAt: time.Now().Add(exposedDataTTL),
|
||||
})
|
||||
return cloneGinH(newData)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package setting
|
||||
package system_setting
|
||||
|
||||
var ServerAddress = "http://localhost:3000"
|
||||
var WorkerUrl = ""
|
||||
@@ -185,6 +185,14 @@ func (e *NewAPIError) ToClaudeError() ClaudeError {
|
||||
type NewAPIErrorOptions func(*NewAPIError)
|
||||
|
||||
func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
var newErr *NewAPIError
|
||||
// 保留深层传递的 new err
|
||||
if errors.As(err, &newErr) {
|
||||
for _, op := range ops {
|
||||
op(newErr)
|
||||
}
|
||||
return newErr
|
||||
}
|
||||
e := &NewAPIError{
|
||||
Err: err,
|
||||
RelayError: nil,
|
||||
@@ -199,8 +207,21 @@ func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPI
|
||||
}
|
||||
|
||||
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError {
|
||||
if errorCode == ErrorCodeDoRequestFailed {
|
||||
err = errors.New("upstream error: do request failed")
|
||||
var newErr *NewAPIError
|
||||
// 保留深层传递的 new err
|
||||
if errors.As(err, &newErr) {
|
||||
if newErr.RelayError == nil {
|
||||
openaiError := OpenAIError{
|
||||
Message: newErr.Error(),
|
||||
Type: string(errorCode),
|
||||
Code: errorCode,
|
||||
}
|
||||
newErr.RelayError = openaiError
|
||||
}
|
||||
for _, op := range ops {
|
||||
op(newErr)
|
||||
}
|
||||
return newErr
|
||||
}
|
||||
openaiError := OpenAIError{
|
||||
Message: err.Error(),
|
||||
@@ -305,6 +326,15 @@ func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions {
|
||||
}
|
||||
}
|
||||
|
||||
func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions {
|
||||
return func(e *NewAPIError) {
|
||||
if common.DebugEnabled {
|
||||
fmt.Printf("ErrOptionWithHideErrMsg: %s, origin error: %s", replaceStr, e.Err)
|
||||
}
|
||||
e.Err = errors.New(replaceStr)
|
||||
}
|
||||
}
|
||||
|
||||
func IsRecordErrorLog(e *NewAPIError) bool {
|
||||
if e == nil {
|
||||
return false
|
||||
|
||||
@@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React, { lazy, Suspense } from 'react';
|
||||
import React, { lazy, Suspense, useContext, useMemo } from 'react';
|
||||
import { Route, Routes, useLocation } from 'react-router-dom';
|
||||
import Loading from './components/common/ui/Loading';
|
||||
import User from './pages/User';
|
||||
@@ -27,6 +27,7 @@ import LoginForm from './components/auth/LoginForm';
|
||||
import NotFound from './pages/NotFound';
|
||||
import Forbidden from './pages/Forbidden';
|
||||
import Setting from './pages/Setting';
|
||||
import { StatusContext } from './context/Status';
|
||||
|
||||
import PasswordResetForm from './components/auth/PasswordResetForm';
|
||||
import PasswordResetConfirm from './components/auth/PasswordResetConfirm';
|
||||
@@ -53,6 +54,29 @@ const About = lazy(() => import('./pages/About'));
|
||||
|
||||
function App() {
|
||||
const location = useLocation();
|
||||
const [statusState] = useContext(StatusContext);
|
||||
|
||||
// 获取模型广场权限配置
|
||||
const pricingRequireAuth = useMemo(() => {
|
||||
const headerNavModulesConfig = statusState?.status?.HeaderNavModules;
|
||||
if (headerNavModulesConfig) {
|
||||
try {
|
||||
const modules = JSON.parse(headerNavModulesConfig);
|
||||
|
||||
// 处理向后兼容性:如果pricing是boolean,默认不需要登录
|
||||
if (typeof modules.pricing === 'boolean') {
|
||||
return false; // 默认不需要登录鉴权
|
||||
}
|
||||
|
||||
// 如果是对象格式,使用requireAuth配置
|
||||
return modules.pricing?.requireAuth === true;
|
||||
} catch (error) {
|
||||
console.error('解析顶栏模块配置失败:', error);
|
||||
return false; // 默认不需要登录
|
||||
}
|
||||
}
|
||||
return false; // 默认不需要登录
|
||||
}, [statusState?.status?.HeaderNavModules]);
|
||||
|
||||
return (
|
||||
<SetupCheck>
|
||||
@@ -253,9 +277,20 @@ function App() {
|
||||
<Route
|
||||
path='/pricing'
|
||||
element={
|
||||
<Suspense fallback={<Loading></Loading>} key={location.pathname}>
|
||||
<Pricing />
|
||||
</Suspense>
|
||||
pricingRequireAuth ? (
|
||||
<PrivateRoute>
|
||||
<Suspense
|
||||
fallback={<Loading></Loading>}
|
||||
key={location.pathname}
|
||||
>
|
||||
<Pricing />
|
||||
</Suspense>
|
||||
</PrivateRoute>
|
||||
) : (
|
||||
<Suspense fallback={<Loading></Loading>} key={location.pathname}>
|
||||
<Pricing />
|
||||
</Suspense>
|
||||
)
|
||||
}
|
||||
/>
|
||||
<Route
|
||||
|
||||
@@ -135,7 +135,7 @@ const TwoFactorAuthModal = ({
|
||||
autoFocus
|
||||
/>
|
||||
<Typography.Text type='tertiary' size='small' className='mt-2 block'>
|
||||
{t('支持6位TOTP验证码或8位备用码')}
|
||||
{t('支持6位TOTP验证码或8位备用码,可到`个人设置-安全设置-两步验证设置`配置或查看。')}
|
||||
</Typography.Text>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -443,7 +443,7 @@ const JSONEditor = ({
|
||||
|
||||
return (
|
||||
<Row key={pair.id} gutter={8} align='middle'>
|
||||
<Col span={6}>
|
||||
<Col span={10}>
|
||||
<div className='relative'>
|
||||
<Input
|
||||
placeholder={t('键名')}
|
||||
@@ -470,7 +470,7 @@ const JSONEditor = ({
|
||||
)}
|
||||
</div>
|
||||
</Col>
|
||||
<Col span={16}>{renderValueInput(pair.id, pair.value)}</Col>
|
||||
<Col span={12}>{renderValueInput(pair.id, pair.value)}</Col>
|
||||
<Col span={2}>
|
||||
<Button
|
||||
icon={<IconDelete />}
|
||||
|
||||
@@ -100,7 +100,7 @@ const ApiInfoPanel = ({
|
||||
</React.Fragment>
|
||||
))
|
||||
) : (
|
||||
<div className='flex justify-center items-center py-8'>
|
||||
<div className='flex justify-center items-center min-h-[20rem] w-full'>
|
||||
<Empty
|
||||
image={<IllustrationConstruction style={ILLUSTRATION_SIZE} />}
|
||||
darkModeImage={
|
||||
|
||||
@@ -20,11 +20,6 @@ For commercial licensing, please contact support@quantumnous.com
|
||||
import React from 'react';
|
||||
import { Card, Tabs, TabPane } from '@douyinfe/semi-ui';
|
||||
import { PieChart } from 'lucide-react';
|
||||
import {
|
||||
IconHistogram,
|
||||
IconPulse,
|
||||
IconPieChart2Stroked,
|
||||
} from '@douyinfe/semi-icons';
|
||||
import { VChart } from '@visactor/react-vchart';
|
||||
|
||||
const ChartsPanel = ({
|
||||
@@ -51,46 +46,14 @@ const ChartsPanel = ({
|
||||
{t('模型数据分析')}
|
||||
</div>
|
||||
<Tabs
|
||||
type='button'
|
||||
type='slash'
|
||||
activeKey={activeChartTab}
|
||||
onChange={setActiveChartTab}
|
||||
>
|
||||
<TabPane
|
||||
tab={
|
||||
<span>
|
||||
<IconHistogram />
|
||||
{t('消耗分布')}
|
||||
</span>
|
||||
}
|
||||
itemKey='1'
|
||||
/>
|
||||
<TabPane
|
||||
tab={
|
||||
<span>
|
||||
<IconPulse />
|
||||
{t('消耗趋势')}
|
||||
</span>
|
||||
}
|
||||
itemKey='2'
|
||||
/>
|
||||
<TabPane
|
||||
tab={
|
||||
<span>
|
||||
<IconPieChart2Stroked />
|
||||
{t('调用次数分布')}
|
||||
</span>
|
||||
}
|
||||
itemKey='3'
|
||||
/>
|
||||
<TabPane
|
||||
tab={
|
||||
<span>
|
||||
<IconHistogram />
|
||||
{t('调用次数排行')}
|
||||
</span>
|
||||
}
|
||||
itemKey='4'
|
||||
/>
|
||||
<TabPane tab={<span>{t('消耗分布')}</span>} itemKey='1' />
|
||||
<TabPane tab={<span>{t('消耗趋势')}</span>} itemKey='2' />
|
||||
<TabPane tab={<span>{t('调用次数分布')}</span>} itemKey='3' />
|
||||
<TabPane tab={<span>{t('调用次数排行')}</span>} itemKey='4' />
|
||||
</Tabs>
|
||||
</div>
|
||||
}
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
export { default } from './HeaderBar/index';
|
||||
@@ -1,148 +0,0 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { Skeleton } from '@douyinfe/semi-ui';
|
||||
|
||||
const SkeletonWrapper = ({
|
||||
loading = false,
|
||||
type = 'text',
|
||||
count = 1,
|
||||
width = 60,
|
||||
height = 16,
|
||||
isMobile = false,
|
||||
className = '',
|
||||
children,
|
||||
...props
|
||||
}) => {
|
||||
if (!loading) {
|
||||
return children;
|
||||
}
|
||||
|
||||
// 导航链接骨架屏
|
||||
const renderNavigationSkeleton = () => {
|
||||
const skeletonLinkClasses = isMobile
|
||||
? 'flex items-center gap-1 p-1 w-full rounded-md'
|
||||
: 'flex items-center gap-1 p-2 rounded-md';
|
||||
|
||||
return Array(count)
|
||||
.fill(null)
|
||||
.map((_, index) => (
|
||||
<div key={index} className={skeletonLinkClasses}>
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: isMobile ? 40 : width, height }}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
));
|
||||
};
|
||||
|
||||
// 用户区域骨架屏 (头像 + 文本)
|
||||
const renderUserAreaSkeleton = () => {
|
||||
return (
|
||||
<div
|
||||
className={`flex items-center p-1 rounded-full bg-semi-color-fill-0 dark:bg-semi-color-fill-1 ${className}`}
|
||||
>
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Avatar active size='extra-small' className='shadow-sm' />
|
||||
}
|
||||
/>
|
||||
<div className='ml-1.5 mr-1'>
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: isMobile ? 15 : width, height: 12 }}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// Logo图片骨架屏
|
||||
const renderImageSkeleton = () => {
|
||||
return (
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Image
|
||||
active
|
||||
className={`absolute inset-0 !rounded-full ${className}`}
|
||||
style={{ width: '100%', height: '100%' }}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
// 系统名称骨架屏
|
||||
const renderTitleSkeleton = () => {
|
||||
return (
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={<Skeleton.Title active style={{ width, height: 24 }} />}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
// 通用文本骨架屏
|
||||
const renderTextSkeleton = () => {
|
||||
return (
|
||||
<div className={className}>
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={<Skeleton.Title active style={{ width, height }} />}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// 根据类型渲染不同的骨架屏
|
||||
switch (type) {
|
||||
case 'navigation':
|
||||
return renderNavigationSkeleton();
|
||||
case 'userArea':
|
||||
return renderUserAreaSkeleton();
|
||||
case 'image':
|
||||
return renderImageSkeleton();
|
||||
case 'title':
|
||||
return renderTitleSkeleton();
|
||||
case 'text':
|
||||
default:
|
||||
return renderTextSkeleton();
|
||||
}
|
||||
};
|
||||
|
||||
export default SkeletonWrapper;
|
||||
@@ -17,7 +17,7 @@ along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import HeaderBar from './HeaderBar';
|
||||
import HeaderBar from './headerbar';
|
||||
import { Layout } from '@douyinfe/semi-ui';
|
||||
import SiderBar from './SiderBar';
|
||||
import App from '../../App';
|
||||
|
||||
@@ -23,7 +23,10 @@ import { useTranslation } from 'react-i18next';
|
||||
import { getLucideIcon } from '../../helpers/render';
|
||||
import { ChevronLeft } from 'lucide-react';
|
||||
import { useSidebarCollapsed } from '../../hooks/common/useSidebarCollapsed';
|
||||
import { useSidebar } from '../../hooks/common/useSidebar';
|
||||
import { useMinimumLoadingTime } from '../../hooks/common/useMinimumLoadingTime';
|
||||
import { isAdmin, isRoot, showError } from '../../helpers';
|
||||
import SkeletonWrapper from './components/SkeletonWrapper';
|
||||
|
||||
import { Nav, Divider, Button } from '@douyinfe/semi-ui';
|
||||
|
||||
@@ -49,6 +52,13 @@ const routerMap = {
|
||||
const SiderBar = ({ onNavigate = () => {} }) => {
|
||||
const { t } = useTranslation();
|
||||
const [collapsed, toggleCollapsed] = useSidebarCollapsed();
|
||||
const {
|
||||
isModuleVisible,
|
||||
hasSectionVisibleModules,
|
||||
loading: sidebarLoading,
|
||||
} = useSidebar();
|
||||
|
||||
const showSkeleton = useMinimumLoadingTime(sidebarLoading);
|
||||
|
||||
const [selectedKeys, setSelectedKeys] = useState(['home']);
|
||||
const [chatItems, setChatItems] = useState([]);
|
||||
@@ -56,8 +66,8 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
||||
const location = useLocation();
|
||||
const [routerMapState, setRouterMapState] = useState(routerMap);
|
||||
|
||||
const workspaceItems = useMemo(
|
||||
() => [
|
||||
const workspaceItems = useMemo(() => {
|
||||
const items = [
|
||||
{
|
||||
text: t('数据看板'),
|
||||
itemKey: 'detail',
|
||||
@@ -93,17 +103,25 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
||||
className:
|
||||
localStorage.getItem('enable_task') === 'true' ? '' : 'tableHiddle',
|
||||
},
|
||||
],
|
||||
[
|
||||
localStorage.getItem('enable_data_export'),
|
||||
localStorage.getItem('enable_drawing'),
|
||||
localStorage.getItem('enable_task'),
|
||||
t,
|
||||
],
|
||||
);
|
||||
];
|
||||
|
||||
const financeItems = useMemo(
|
||||
() => [
|
||||
// 根据配置过滤项目
|
||||
const filteredItems = items.filter((item) => {
|
||||
const configVisible = isModuleVisible('console', item.itemKey);
|
||||
return configVisible;
|
||||
});
|
||||
|
||||
return filteredItems;
|
||||
}, [
|
||||
localStorage.getItem('enable_data_export'),
|
||||
localStorage.getItem('enable_drawing'),
|
||||
localStorage.getItem('enable_task'),
|
||||
t,
|
||||
isModuleVisible,
|
||||
]);
|
||||
|
||||
const financeItems = useMemo(() => {
|
||||
const items = [
|
||||
{
|
||||
text: t('钱包管理'),
|
||||
itemKey: 'topup',
|
||||
@@ -114,12 +132,19 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
||||
itemKey: 'personal',
|
||||
to: '/personal',
|
||||
},
|
||||
],
|
||||
[t],
|
||||
);
|
||||
];
|
||||
|
||||
const adminItems = useMemo(
|
||||
() => [
|
||||
// 根据配置过滤项目
|
||||
const filteredItems = items.filter((item) => {
|
||||
const configVisible = isModuleVisible('personal', item.itemKey);
|
||||
return configVisible;
|
||||
});
|
||||
|
||||
return filteredItems;
|
||||
}, [t, isModuleVisible]);
|
||||
|
||||
const adminItems = useMemo(() => {
|
||||
const items = [
|
||||
{
|
||||
text: t('渠道管理'),
|
||||
itemKey: 'channel',
|
||||
@@ -150,12 +175,19 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
||||
to: '/setting',
|
||||
className: isRoot() ? '' : 'tableHiddle',
|
||||
},
|
||||
],
|
||||
[isAdmin(), isRoot(), t],
|
||||
);
|
||||
];
|
||||
|
||||
const chatMenuItems = useMemo(
|
||||
() => [
|
||||
// 根据配置过滤项目
|
||||
const filteredItems = items.filter((item) => {
|
||||
const configVisible = isModuleVisible('admin', item.itemKey);
|
||||
return configVisible;
|
||||
});
|
||||
|
||||
return filteredItems;
|
||||
}, [isAdmin(), isRoot(), t, isModuleVisible]);
|
||||
|
||||
const chatMenuItems = useMemo(() => {
|
||||
const items = [
|
||||
{
|
||||
text: t('操练场'),
|
||||
itemKey: 'playground',
|
||||
@@ -166,9 +198,16 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
||||
itemKey: 'chat',
|
||||
items: chatItems,
|
||||
},
|
||||
],
|
||||
[chatItems, t],
|
||||
);
|
||||
];
|
||||
|
||||
// 根据配置过滤项目
|
||||
const filteredItems = items.filter((item) => {
|
||||
const configVisible = isModuleVisible('chat', item.itemKey);
|
||||
return configVisible;
|
||||
});
|
||||
|
||||
return filteredItems;
|
||||
}, [chatItems, t, isModuleVisible]);
|
||||
|
||||
// 更新路由映射,添加聊天路由
|
||||
const updateRouterMapWithChats = (chats) => {
|
||||
@@ -213,7 +252,6 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
||||
updateRouterMapWithChats(chats);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
showError('聊天数据解析失败');
|
||||
}
|
||||
}
|
||||
@@ -267,14 +305,12 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
||||
key={item.itemKey}
|
||||
itemKey={item.itemKey}
|
||||
text={
|
||||
<div className='flex items-center'>
|
||||
<span
|
||||
className='truncate font-medium text-sm'
|
||||
style={{ color: textColor }}
|
||||
>
|
||||
{item.text}
|
||||
</span>
|
||||
</div>
|
||||
<span
|
||||
className='truncate font-medium text-sm'
|
||||
style={{ color: textColor }}
|
||||
>
|
||||
{item.text}
|
||||
</span>
|
||||
}
|
||||
icon={
|
||||
<div className='sidebar-icon-container flex-shrink-0'>
|
||||
@@ -297,14 +333,12 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
||||
key={item.itemKey}
|
||||
itemKey={item.itemKey}
|
||||
text={
|
||||
<div className='flex items-center'>
|
||||
<span
|
||||
className='truncate font-medium text-sm'
|
||||
style={{ color: textColor }}
|
||||
>
|
||||
{item.text}
|
||||
</span>
|
||||
</div>
|
||||
<span
|
||||
className='truncate font-medium text-sm'
|
||||
style={{ color: textColor }}
|
||||
>
|
||||
{item.text}
|
||||
</span>
|
||||
}
|
||||
icon={
|
||||
<div className='sidebar-icon-container flex-shrink-0'>
|
||||
@@ -341,110 +375,142 @@ const SiderBar = ({ onNavigate = () => {} }) => {
|
||||
return (
|
||||
<div
|
||||
className='sidebar-container'
|
||||
style={{ width: 'var(--sidebar-current-width)' }}
|
||||
style={{
|
||||
width: 'var(--sidebar-current-width)',
|
||||
background: 'var(--semi-color-bg-0)',
|
||||
}}
|
||||
>
|
||||
<Nav
|
||||
className='sidebar-nav'
|
||||
defaultIsCollapsed={collapsed}
|
||||
isCollapsed={collapsed}
|
||||
onCollapseChange={toggleCollapsed}
|
||||
selectedKeys={selectedKeys}
|
||||
itemStyle='sidebar-nav-item'
|
||||
hoverStyle='sidebar-nav-item:hover'
|
||||
selectedStyle='sidebar-nav-item-selected'
|
||||
renderWrapper={({ itemElement, props }) => {
|
||||
const to = routerMapState[props.itemKey] || routerMap[props.itemKey];
|
||||
|
||||
// 如果没有路由,直接返回元素
|
||||
if (!to) return itemElement;
|
||||
|
||||
return (
|
||||
<Link
|
||||
style={{ textDecoration: 'none' }}
|
||||
to={to}
|
||||
onClick={onNavigate}
|
||||
>
|
||||
{itemElement}
|
||||
</Link>
|
||||
);
|
||||
}}
|
||||
onSelect={(key) => {
|
||||
// 如果点击的是已经展开的子菜单的父项,则收起子菜单
|
||||
if (openedKeys.includes(key.itemKey)) {
|
||||
setOpenedKeys(openedKeys.filter((k) => k !== key.itemKey));
|
||||
}
|
||||
|
||||
setSelectedKeys([key.itemKey]);
|
||||
}}
|
||||
openKeys={openedKeys}
|
||||
onOpenChange={(data) => {
|
||||
setOpenedKeys(data.openKeys);
|
||||
}}
|
||||
<SkeletonWrapper
|
||||
loading={showSkeleton}
|
||||
type='sidebar'
|
||||
className=''
|
||||
collapsed={collapsed}
|
||||
showAdmin={isAdmin()}
|
||||
>
|
||||
{/* 聊天区域 */}
|
||||
<div className='sidebar-section'>
|
||||
{!collapsed && <div className='sidebar-group-label'>{t('聊天')}</div>}
|
||||
{chatMenuItems.map((item) => renderSubItem(item))}
|
||||
</div>
|
||||
<Nav
|
||||
className='sidebar-nav'
|
||||
defaultIsCollapsed={collapsed}
|
||||
isCollapsed={collapsed}
|
||||
onCollapseChange={toggleCollapsed}
|
||||
selectedKeys={selectedKeys}
|
||||
itemStyle='sidebar-nav-item'
|
||||
hoverStyle='sidebar-nav-item:hover'
|
||||
selectedStyle='sidebar-nav-item-selected'
|
||||
renderWrapper={({ itemElement, props }) => {
|
||||
const to =
|
||||
routerMapState[props.itemKey] || routerMap[props.itemKey];
|
||||
|
||||
{/* 控制台区域 */}
|
||||
<Divider className='sidebar-divider' />
|
||||
<div>
|
||||
{!collapsed && (
|
||||
<div className='sidebar-group-label'>{t('控制台')}</div>
|
||||
)}
|
||||
{workspaceItems.map((item) => renderNavItem(item))}
|
||||
</div>
|
||||
// 如果没有路由,直接返回元素
|
||||
if (!to) return itemElement;
|
||||
|
||||
{/* 个人中心区域 */}
|
||||
<Divider className='sidebar-divider' />
|
||||
<div>
|
||||
{!collapsed && (
|
||||
<div className='sidebar-group-label'>{t('个人中心')}</div>
|
||||
)}
|
||||
{financeItems.map((item) => renderNavItem(item))}
|
||||
</div>
|
||||
return (
|
||||
<Link
|
||||
style={{ textDecoration: 'none' }}
|
||||
to={to}
|
||||
onClick={onNavigate}
|
||||
>
|
||||
{itemElement}
|
||||
</Link>
|
||||
);
|
||||
}}
|
||||
onSelect={(key) => {
|
||||
// 如果点击的是已经展开的子菜单的父项,则收起子菜单
|
||||
if (openedKeys.includes(key.itemKey)) {
|
||||
setOpenedKeys(openedKeys.filter((k) => k !== key.itemKey));
|
||||
}
|
||||
|
||||
{/* 管理员区域 - 只在管理员时显示 */}
|
||||
{isAdmin() && (
|
||||
<>
|
||||
<Divider className='sidebar-divider' />
|
||||
<div>
|
||||
setSelectedKeys([key.itemKey]);
|
||||
}}
|
||||
openKeys={openedKeys}
|
||||
onOpenChange={(data) => {
|
||||
setOpenedKeys(data.openKeys);
|
||||
}}
|
||||
>
|
||||
{/* 聊天区域 */}
|
||||
{hasSectionVisibleModules('chat') && (
|
||||
<div className='sidebar-section'>
|
||||
{!collapsed && (
|
||||
<div className='sidebar-group-label'>{t('管理员')}</div>
|
||||
<div className='sidebar-group-label'>{t('聊天')}</div>
|
||||
)}
|
||||
{adminItems.map((item) => renderNavItem(item))}
|
||||
{chatMenuItems.map((item) => renderSubItem(item))}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</Nav>
|
||||
)}
|
||||
|
||||
{/* 控制台区域 */}
|
||||
{hasSectionVisibleModules('console') && (
|
||||
<>
|
||||
<Divider className='sidebar-divider' />
|
||||
<div>
|
||||
{!collapsed && (
|
||||
<div className='sidebar-group-label'>{t('控制台')}</div>
|
||||
)}
|
||||
{workspaceItems.map((item) => renderNavItem(item))}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* 个人中心区域 */}
|
||||
{hasSectionVisibleModules('personal') && (
|
||||
<>
|
||||
<Divider className='sidebar-divider' />
|
||||
<div>
|
||||
{!collapsed && (
|
||||
<div className='sidebar-group-label'>{t('个人中心')}</div>
|
||||
)}
|
||||
{financeItems.map((item) => renderNavItem(item))}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* 管理员区域 - 只在管理员时显示且配置允许时显示 */}
|
||||
{isAdmin() && hasSectionVisibleModules('admin') && (
|
||||
<>
|
||||
<Divider className='sidebar-divider' />
|
||||
<div>
|
||||
{!collapsed && (
|
||||
<div className='sidebar-group-label'>{t('管理员')}</div>
|
||||
)}
|
||||
{adminItems.map((item) => renderNavItem(item))}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</Nav>
|
||||
</SkeletonWrapper>
|
||||
|
||||
{/* 底部折叠按钮 */}
|
||||
<div className='sidebar-collapse-button'>
|
||||
<Button
|
||||
theme='outline'
|
||||
type='tertiary'
|
||||
size='small'
|
||||
icon={
|
||||
<ChevronLeft
|
||||
size={16}
|
||||
strokeWidth={2.5}
|
||||
color='var(--semi-color-text-2)'
|
||||
style={{
|
||||
transform: collapsed ? 'rotate(180deg)' : 'rotate(0deg)',
|
||||
}}
|
||||
/>
|
||||
}
|
||||
onClick={toggleCollapsed}
|
||||
icononly={collapsed}
|
||||
style={
|
||||
collapsed
|
||||
? { padding: '4px', width: '100%' }
|
||||
: { padding: '4px 12px', width: '100%' }
|
||||
}
|
||||
<SkeletonWrapper
|
||||
loading={showSkeleton}
|
||||
type='button'
|
||||
width={collapsed ? 36 : 156}
|
||||
height={24}
|
||||
className='w-full'
|
||||
>
|
||||
{!collapsed ? t('收起侧边栏') : null}
|
||||
</Button>
|
||||
<Button
|
||||
theme='outline'
|
||||
type='tertiary'
|
||||
size='small'
|
||||
icon={
|
||||
<ChevronLeft
|
||||
size={16}
|
||||
strokeWidth={2.5}
|
||||
color='var(--semi-color-text-2)'
|
||||
style={{
|
||||
transform: collapsed ? 'rotate(180deg)' : 'rotate(0deg)',
|
||||
}}
|
||||
/>
|
||||
}
|
||||
onClick={toggleCollapsed}
|
||||
icononly={collapsed}
|
||||
style={
|
||||
collapsed
|
||||
? { width: 36, height: 24, padding: 0 }
|
||||
: { padding: '4px 12px', width: '100%' }
|
||||
}
|
||||
>
|
||||
{!collapsed ? t('收起侧边栏') : null}
|
||||
</Button>
|
||||
</SkeletonWrapper>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
394
web/src/components/layout/components/SkeletonWrapper.jsx
Normal file
394
web/src/components/layout/components/SkeletonWrapper.jsx
Normal file
@@ -0,0 +1,394 @@
|
||||
/*
|
||||
Copyright (C) 2025 QuantumNous
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as
|
||||
published by the Free Software Foundation, either version 3 of the
|
||||
License, or (at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
For commercial licensing, please contact support@quantumnous.com
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { Skeleton } from '@douyinfe/semi-ui';
|
||||
|
||||
const SkeletonWrapper = ({
|
||||
loading = false,
|
||||
type = 'text',
|
||||
count = 1,
|
||||
width = 60,
|
||||
height = 16,
|
||||
isMobile = false,
|
||||
className = '',
|
||||
collapsed = false,
|
||||
showAdmin = true,
|
||||
children,
|
||||
...props
|
||||
}) => {
|
||||
if (!loading) {
|
||||
return children;
|
||||
}
|
||||
|
||||
// 导航链接骨架屏
|
||||
const renderNavigationSkeleton = () => {
|
||||
const skeletonLinkClasses = isMobile
|
||||
? 'flex items-center gap-1 p-1 w-full rounded-md'
|
||||
: 'flex items-center gap-1 p-2 rounded-md';
|
||||
|
||||
return Array(count)
|
||||
.fill(null)
|
||||
.map((_, index) => (
|
||||
<div key={index} className={skeletonLinkClasses}>
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: isMobile ? 40 : width, height }}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
));
|
||||
};
|
||||
|
||||
// 用户区域骨架屏 (头像 + 文本)
|
||||
const renderUserAreaSkeleton = () => {
|
||||
return (
|
||||
<div
|
||||
className={`flex items-center p-1 rounded-full bg-semi-color-fill-0 dark:bg-semi-color-fill-1 ${className}`}
|
||||
>
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Avatar active size='extra-small' className='shadow-sm' />
|
||||
}
|
||||
/>
|
||||
<div className='ml-1.5 mr-1'>
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: isMobile ? 15 : width, height: 12 }}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// Logo图片骨架屏
|
||||
const renderImageSkeleton = () => {
|
||||
return (
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Image
|
||||
active
|
||||
className={`absolute inset-0 !rounded-full ${className}`}
|
||||
style={{ width: '100%', height: '100%' }}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
// 系统名称骨架屏
|
||||
const renderTitleSkeleton = () => {
|
||||
return (
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={<Skeleton.Title active style={{ width, height: 24 }} />}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
// 通用文本骨架屏
|
||||
const renderTextSkeleton = () => {
|
||||
return (
|
||||
<div className={className}>
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={<Skeleton.Title active style={{ width, height }} />}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// 按钮骨架屏(支持圆角)
|
||||
const renderButtonSkeleton = () => {
|
||||
return (
|
||||
<div className={className}>
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width, height, borderRadius: 9999 }}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// 侧边栏导航项骨架屏 (图标 + 文本)
|
||||
const renderSidebarNavItemSkeleton = () => {
|
||||
return Array(count)
|
||||
.fill(null)
|
||||
.map((_, index) => (
|
||||
<div
|
||||
key={index}
|
||||
className={`flex items-center p-2 mb-1 rounded-md ${className}`}
|
||||
>
|
||||
{/* 图标骨架屏 */}
|
||||
<div className='sidebar-icon-container flex-shrink-0'>
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Avatar active size='extra-small' shape='square' />
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
{/* 文本骨架屏 */}
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: width || 80, height: height || 14 }}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
));
|
||||
};
|
||||
|
||||
// 侧边栏组标题骨架屏
|
||||
const renderSidebarGroupTitleSkeleton = () => {
|
||||
return (
|
||||
<div className={`mb-2 ${className}`}>
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: width || 60, height: height || 12 }}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// 完整侧边栏骨架屏 - 1:1 还原,去重实现
|
||||
const renderSidebarSkeleton = () => {
|
||||
const NAV_WIDTH = 164;
|
||||
const NAV_HEIGHT = 30;
|
||||
const COLLAPSED_WIDTH = 44;
|
||||
const COLLAPSED_HEIGHT = 44;
|
||||
const ICON_SIZE = 16;
|
||||
const TITLE_HEIGHT = 12;
|
||||
const TEXT_HEIGHT = 16;
|
||||
|
||||
const renderIcon = () => (
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Avatar
|
||||
active
|
||||
shape='square'
|
||||
style={{ width: ICON_SIZE, height: ICON_SIZE }}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
);
|
||||
|
||||
const renderLabel = (labelWidth) => (
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: labelWidth, height: TEXT_HEIGHT }}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
);
|
||||
|
||||
const NavRow = ({ labelWidth }) => (
|
||||
<div
|
||||
className='flex items-center p-2 mb-1 rounded-md'
|
||||
style={{
|
||||
width: `${NAV_WIDTH}px`,
|
||||
height: `${NAV_HEIGHT}px`,
|
||||
margin: '3px 8px',
|
||||
}}
|
||||
>
|
||||
<div className='sidebar-icon-container flex-shrink-0'>
|
||||
{renderIcon()}
|
||||
</div>
|
||||
{renderLabel(labelWidth)}
|
||||
</div>
|
||||
);
|
||||
|
||||
const CollapsedRow = ({ keyPrefix, index }) => (
|
||||
<div
|
||||
key={`${keyPrefix}-${index}`}
|
||||
className='flex items-center justify-center'
|
||||
style={{
|
||||
width: `${COLLAPSED_WIDTH}px`,
|
||||
height: `${COLLAPSED_HEIGHT}px`,
|
||||
margin: '0 8px 4px 8px',
|
||||
}}
|
||||
>
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Avatar
|
||||
active
|
||||
shape='square'
|
||||
style={{ width: ICON_SIZE, height: ICON_SIZE }}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
if (collapsed) {
|
||||
return (
|
||||
<div className={`w-full ${className}`} style={{ paddingTop: '12px' }}>
|
||||
{Array(2)
|
||||
.fill(null)
|
||||
.map((_, i) => (
|
||||
<CollapsedRow keyPrefix='c-chat' index={i} />
|
||||
))}
|
||||
{Array(5)
|
||||
.fill(null)
|
||||
.map((_, i) => (
|
||||
<CollapsedRow keyPrefix='c-console' index={i} />
|
||||
))}
|
||||
{Array(2)
|
||||
.fill(null)
|
||||
.map((_, i) => (
|
||||
<CollapsedRow keyPrefix='c-personal' index={i} />
|
||||
))}
|
||||
{Array(5)
|
||||
.fill(null)
|
||||
.map((_, i) => (
|
||||
<CollapsedRow keyPrefix='c-admin' index={i} />
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const sections = [
|
||||
{ key: 'chat', titleWidth: 32, itemWidths: [54, 32], wrapper: 'section' },
|
||||
{ key: 'console', titleWidth: 48, itemWidths: [64, 64, 64, 64, 64] },
|
||||
{ key: 'personal', titleWidth: 64, itemWidths: [64, 64] },
|
||||
...(showAdmin
|
||||
? [{ key: 'admin', titleWidth: 48, itemWidths: [64, 64, 80, 64, 64] }]
|
||||
: []),
|
||||
];
|
||||
|
||||
return (
|
||||
<div className={`w-full ${className}`} style={{ paddingTop: '12px' }}>
|
||||
{sections.map((sec, idx) => (
|
||||
<React.Fragment key={sec.key}>
|
||||
{sec.wrapper === 'section' ? (
|
||||
<div className='sidebar-section'>
|
||||
<div
|
||||
className='sidebar-group-label'
|
||||
style={{ padding: '4px 15px 8px' }}
|
||||
>
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: sec.titleWidth, height: TITLE_HEIGHT }}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
{sec.itemWidths.map((w, i) => (
|
||||
<NavRow key={`${sec.key}-${i}`} labelWidth={w} />
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<div>
|
||||
<div
|
||||
className='sidebar-group-label'
|
||||
style={{ padding: '4px 15px 8px' }}
|
||||
>
|
||||
<Skeleton
|
||||
loading={true}
|
||||
active
|
||||
placeholder={
|
||||
<Skeleton.Title
|
||||
active
|
||||
style={{ width: sec.titleWidth, height: TITLE_HEIGHT }}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
{sec.itemWidths.map((w, i) => (
|
||||
<NavRow key={`${sec.key}-${i}`} labelWidth={w} />
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// 根据类型渲染不同的骨架屏
|
||||
switch (type) {
|
||||
case 'navigation':
|
||||
return renderNavigationSkeleton();
|
||||
case 'userArea':
|
||||
return renderUserAreaSkeleton();
|
||||
case 'image':
|
||||
return renderImageSkeleton();
|
||||
case 'title':
|
||||
return renderTitleSkeleton();
|
||||
case 'sidebarNavItem':
|
||||
return renderSidebarNavItemSkeleton();
|
||||
case 'sidebarGroupTitle':
|
||||
return renderSidebarGroupTitleSkeleton();
|
||||
case 'sidebar':
|
||||
return renderSidebarSkeleton();
|
||||
case 'button':
|
||||
return renderButtonSkeleton();
|
||||
case 'text':
|
||||
default:
|
||||
return renderTextSkeleton();
|
||||
}
|
||||
};
|
||||
|
||||
export default SkeletonWrapper;
|
||||
@@ -20,7 +20,7 @@ For commercial licensing, please contact support@quantumnous.com
|
||||
import React from 'react';
|
||||
import { Link } from 'react-router-dom';
|
||||
import { Typography, Tag } from '@douyinfe/semi-ui';
|
||||
import SkeletonWrapper from './SkeletonWrapper';
|
||||
import SkeletonWrapper from '../components/SkeletonWrapper';
|
||||
|
||||
const HeaderLogo = ({
|
||||
isMobile,
|
||||
@@ -19,9 +19,15 @@ For commercial licensing, please contact support@quantumnous.com
|
||||
|
||||
import React from 'react';
|
||||
import { Link } from 'react-router-dom';
|
||||
import SkeletonWrapper from './SkeletonWrapper';
|
||||
import SkeletonWrapper from '../components/SkeletonWrapper';
|
||||
|
||||
const Navigation = ({ mainNavLinks, isMobile, isLoading, userState }) => {
|
||||
const Navigation = ({
|
||||
mainNavLinks,
|
||||
isMobile,
|
||||
isLoading,
|
||||
userState,
|
||||
pricingRequireAuth,
|
||||
}) => {
|
||||
const renderNavLinks = () => {
|
||||
const baseClasses =
|
||||
'flex-shrink-0 flex items-center gap-1 font-semibold rounded-md transition-all duration-200 ease-in-out';
|
||||
@@ -51,6 +57,9 @@ const Navigation = ({ mainNavLinks, isMobile, isLoading, userState }) => {
|
||||
if (link.itemKey === 'console' && !userState.user) {
|
||||
targetPath = '/login';
|
||||
}
|
||||
if (link.itemKey === 'pricing' && pricingRequireAuth && !userState.user) {
|
||||
targetPath = '/login';
|
||||
}
|
||||
|
||||
return (
|
||||
<Link key={link.itemKey} to={targetPath} className={commonLinkClasses}>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user