diff --git a/.env.example b/.env.example index 72645404e..c7851385b 100644 --- a/.env.example +++ b/.env.example @@ -56,8 +56,6 @@ # SESSION_SECRET=random_string # 其他配置 -# 渠道测试频率(单位:秒) -# CHANNEL_TEST_FREQUENCY=10 # 生成默认token # GENERATE_DEFAULT_TOKEN=false # Cohere 安全设置 diff --git a/README.md b/README.md index 45b048340..d68b3e135 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,11 @@ New API提供了丰富的功能,详细特性请参考[特性说明](https://do - 添加后缀 `-thinking` 启用思考模式 (例如: `claude-3-7-sonnet-20250219-thinking`) 16. 🔄 思考转内容功能 17. 🔄 针对用户的模型限流功能 -18. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费: +18. 🔄 请求格式转换功能,支持以下三种格式转换: + 1. OpenAI Chat Completions => Claude Messages + 2. Clade Messages => OpenAI Chat Completions (可用于Claude Code调用第三方模型) + 3. OpenAI Chat Completions => Gemini Chat +19. 💰 缓存计费支持,开启后可以在缓存命中时按照设定的比例计费: 1. 在 `系统设置-运营设置` 中设置 `提示缓存倍率` 选项 2. 在渠道中设置 `提示缓存倍率`,范围 0-1,例如设置为 0.5 表示缓存命中时按照 50% 计费 3. 支持的渠道: diff --git a/common/copy.go b/common/copy.go index 8573d6e0b..3edb2fa25 100644 --- a/common/copy.go +++ b/common/copy.go @@ -2,7 +2,8 @@ package common import ( "fmt" - "github.com/antlabs/pcopy" + + "github.com/jinzhu/copier" ) func DeepCopy[T any](src *T) (*T, error) { @@ -10,12 +11,9 @@ func DeepCopy[T any](src *T) (*T, error) { return nil, fmt.Errorf("copy source cannot be nil") } var dst T - err := pcopy.Copy(&dst, src) + err := copier.CopyWithOption(&dst, src, copier.Option{DeepCopy: true, IgnoreEmpty: true}) if err != nil { return nil, err } - if &dst == nil { - return nil, fmt.Errorf("copy result cannot be nil") - } return &dst, nil } diff --git a/common/json.go b/common/json.go index 69aa952e9..13e23a460 100644 --- a/common/json.go +++ b/common/json.go @@ -20,3 +20,25 @@ func DecodeJson(reader *bytes.Reader, v any) error { func Marshal(v any) ([]byte, error) { return json.Marshal(v) } + +func GetJsonType(data json.RawMessage) string { + data = bytes.TrimSpace(data) + if len(data) == 0 { + return "unknown" + } + firstChar := bytes.TrimSpace(data)[0] + switch firstChar { + case '{': + return "object" + case '[': + return "array" + case '"': + return "string" + case 't', 'f': + return "boolean" + case 'n': + return "null" + default: + return "number" + } +} diff --git a/common/utils.go b/common/utils.go index 17aecd950..883abfd1a 100644 --- a/common/utils.go +++ b/common/utils.go @@ -123,8 +123,16 @@ func Interface2String(inter interface{}) string { return fmt.Sprintf("%d", inter.(int)) case float64: return fmt.Sprintf("%f", inter.(float64)) + case bool: + if inter.(bool) { + return "true" + } else { + return "false" + } + case nil: + return "" } - return "Not Implemented" + return fmt.Sprintf("%v", inter) } func UnescapeHTML(x string) interface{} { @@ -257,32 +265,32 @@ func GetAudioDuration(ctx context.Context, filename string, ext string) (float64 if err != nil { return 0, errors.Wrap(err, "failed to get audio duration") } - durationStr := string(bytes.TrimSpace(output)) - if durationStr == "N/A" { - // Create a temporary output file name - tmpFp, err := os.CreateTemp("", "audio-*"+ext) - if err != nil { - return 0, errors.Wrap(err, "failed to create temporary file") - } - tmpName := tmpFp.Name() - // Close immediately so ffmpeg can open the file on Windows. - _ = tmpFp.Close() - defer os.Remove(tmpName) + durationStr := string(bytes.TrimSpace(output)) + if durationStr == "N/A" { + // Create a temporary output file name + tmpFp, err := os.CreateTemp("", "audio-*"+ext) + if err != nil { + return 0, errors.Wrap(err, "failed to create temporary file") + } + tmpName := tmpFp.Name() + // Close immediately so ffmpeg can open the file on Windows. + _ = tmpFp.Close() + defer os.Remove(tmpName) - // ffmpeg -y -i filename -vcodec copy -acodec copy - ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName) - if err := ffmpegCmd.Run(); err != nil { - return 0, errors.Wrap(err, "failed to run ffmpeg") - } + // ffmpeg -y -i filename -vcodec copy -acodec copy + ffmpegCmd := exec.CommandContext(ctx, "ffmpeg", "-y", "-i", filename, "-vcodec", "copy", "-acodec", "copy", tmpName) + if err := ffmpegCmd.Run(); err != nil { + return 0, errors.Wrap(err, "failed to run ffmpeg") + } - // Recalculate the duration of the new file - c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName) - output, err := c.Output() - if err != nil { - return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg") - } - durationStr = string(bytes.TrimSpace(output)) - } + // Recalculate the duration of the new file + c = exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", tmpName) + output, err := c.Output() + if err != nil { + return 0, errors.Wrap(err, "failed to get audio duration after ffmpeg") + } + durationStr = string(bytes.TrimSpace(output)) + } return strconv.ParseFloat(durationStr, 64) } diff --git a/controller/channel-test.go b/controller/channel-test.go index 81f7e19ab..5a668c488 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -20,6 +20,7 @@ import ( relayconstant "one-api/relay/constant" "one-api/relay/helper" "one-api/service" + "one-api/setting/operation_setting" "one-api/types" "strconv" "strings" @@ -234,7 +235,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - err := service.RelayErrorHandler(httpResp, true) + err := service.RelayErrorHandler(c.Request.Context(), httpResp, true) return testResult{ context: c, localErr: err, @@ -477,15 +478,26 @@ func TestAllChannels(c *gin.Context) { return } -func AutomaticallyTestChannels(frequency int) { - if frequency <= 0 { - common.SysLog("CHANNEL_TEST_FREQUENCY is not set or invalid, skipping automatic channel test") - return - } - for { - time.Sleep(time.Duration(frequency) * time.Minute) - common.SysLog("testing all channels") - _ = testAllChannels(false) - common.SysLog("channel test finished") - } +var autoTestChannelsOnce sync.Once + +func AutomaticallyTestChannels() { + autoTestChannelsOnce.Do(func() { + for { + if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled { + time.Sleep(10 * time.Minute) + continue + } + frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes + common.SysLog(fmt.Sprintf("automatically test channels with interval %d minutes", frequency)) + for { + time.Sleep(time.Duration(frequency) * time.Minute) + common.SysLog("automatically testing all channels") + _ = testAllChannels(false) + common.SysLog("automatically channel test finished") + if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled { + break + } + } + } + }) } diff --git a/controller/channel.go b/controller/channel.go index 020a3327a..70be91d42 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -380,6 +380,85 @@ func GetChannel(c *gin.Context) { return } +// GetChannelKey 验证2FA后获取渠道密钥 +func GetChannelKey(c *gin.Context) { + type GetChannelKeyRequest struct { + Code string `json:"code" binding:"required"` + } + + var req GetChannelKeyRequest + if err := c.ShouldBindJSON(&req); err != nil { + common.ApiError(c, fmt.Errorf("参数错误: %v", err)) + return + } + + userId := c.GetInt("id") + channelId, err := strconv.Atoi(c.Param("id")) + if err != nil { + common.ApiError(c, fmt.Errorf("渠道ID格式错误: %v", err)) + return + } + + // 获取2FA记录并验证 + twoFA, err := model.GetTwoFAByUserId(userId) + if err != nil { + common.ApiError(c, fmt.Errorf("获取2FA信息失败: %v", err)) + return + } + + if twoFA == nil || !twoFA.IsEnabled { + common.ApiError(c, fmt.Errorf("用户未启用2FA,无法查看密钥")) + return + } + + // 统一的2FA验证逻辑 + if !validateTwoFactorAuth(twoFA, req.Code) { + common.ApiError(c, fmt.Errorf("验证码或备用码错误,请重试")) + return + } + + // 获取渠道信息(包含密钥) + channel, err := model.GetChannelById(channelId, true) + if err != nil { + common.ApiError(c, fmt.Errorf("获取渠道信息失败: %v", err)) + return + } + + if channel == nil { + common.ApiError(c, fmt.Errorf("渠道不存在")) + return + } + + // 记录操作日志 + model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("查看渠道密钥信息 (渠道ID: %d)", channelId)) + + // 统一的成功响应格式 + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "验证成功", + "data": map[string]interface{}{ + "key": channel.Key, + }, + }) +} + +// validateTwoFactorAuth 统一的2FA验证函数 +func validateTwoFactorAuth(twoFA *model.TwoFA, code string) bool { + // 尝试验证TOTP + if cleanCode, err := common.ValidateNumericCode(code); err == nil { + if isValid, _ := twoFA.ValidateTOTPAndUpdateUsage(cleanCode); isValid { + return true + } + } + + // 尝试验证备用码 + if isValid, err := twoFA.ValidateBackupCodeAndUpdateUsage(code); err == nil && isValid { + return true + } + + return false +} + // validateChannel 通用的渠道校验函数 func validateChannel(channel *model.Channel, isAdd bool) error { // 校验 channel settings diff --git a/controller/misc.go b/controller/misc.go index f30ab8c79..897dad254 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -39,6 +39,8 @@ func TestStatus(c *gin.Context) { func GetStatus(c *gin.Context) { cs := console_setting.GetConsoleSetting() + common.OptionMapRWMutex.RLock() + defer common.OptionMapRWMutex.RUnlock() data := gin.H{ "version": common.Version, @@ -89,6 +91,10 @@ func GetStatus(c *gin.Context) { "announcements_enabled": cs.AnnouncementsEnabled, "faq_enabled": cs.FAQEnabled, + // 模块管理配置 + "HeaderNavModules": common.OptionMap["HeaderNavModules"], + "SidebarModulesAdmin": common.OptionMap["SidebarModulesAdmin"], + "oidc_enabled": system_setting.GetOIDCSettings().Enabled, "oidc_client_id": system_setting.GetOIDCSettings().ClientId, "oidc_authorization_endpoint": system_setting.GetOIDCSettings().AuthorizationEndpoint, diff --git a/controller/model.go b/controller/model.go index 398503e8d..f0571b995 100644 --- a/controller/model.go +++ b/controller/model.go @@ -207,6 +207,7 @@ func ListModels(c *gin.Context, modelType int) { c.JSON(200, gin.H{ "success": true, "data": userOpenAiModels, + "object": "list", }) } } diff --git a/controller/model_sync.go b/controller/model_sync.go new file mode 100644 index 000000000..74034b51a --- /dev/null +++ b/controller/model_sync.go @@ -0,0 +1,604 @@ +package controller + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "strings" + "sync" + "time" + + "one-api/common" + "one-api/model" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +// 上游地址 +const ( + upstreamModelsURL = "https://basellm.github.io/llm-metadata/api/newapi/models.json" + upstreamVendorsURL = "https://basellm.github.io/llm-metadata/api/newapi/vendors.json" +) + +func normalizeLocale(locale string) (string, bool) { + l := strings.ToLower(strings.TrimSpace(locale)) + switch l { + case "en", "zh", "ja": + return l, true + default: + return "", false + } +} + +func getUpstreamBase() string { + return common.GetEnvOrDefaultString("SYNC_UPSTREAM_BASE", "https://basellm.github.io/llm-metadata") +} + +func getUpstreamURLs(locale string) (modelsURL, vendorsURL string) { + base := strings.TrimRight(getUpstreamBase(), "/") + if l, ok := normalizeLocale(locale); ok && l != "" { + return fmt.Sprintf("%s/api/i18n/%s/newapi/models.json", base, l), + fmt.Sprintf("%s/api/i18n/%s/newapi/vendors.json", base, l) + } + return fmt.Sprintf("%s/api/newapi/models.json", base), fmt.Sprintf("%s/api/newapi/vendors.json", base) +} + +type upstreamEnvelope[T any] struct { + Success bool `json:"success"` + Message string `json:"message"` + Data []T `json:"data"` +} + +type upstreamModel struct { + Description string `json:"description"` + Endpoints json.RawMessage `json:"endpoints"` + Icon string `json:"icon"` + ModelName string `json:"model_name"` + NameRule int `json:"name_rule"` + Status int `json:"status"` + Tags string `json:"tags"` + VendorName string `json:"vendor_name"` +} + +type upstreamVendor struct { + Description string `json:"description"` + Icon string `json:"icon"` + Name string `json:"name"` + Status int `json:"status"` +} + +var ( + etagCache = make(map[string]string) + bodyCache = make(map[string][]byte) + cacheMutex sync.RWMutex +) + +type overwriteField struct { + ModelName string `json:"model_name"` + Fields []string `json:"fields"` +} + +type syncRequest struct { + Overwrite []overwriteField `json:"overwrite"` + Locale string `json:"locale"` +} + +func newHTTPClient() *http.Client { + timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 10) + dialer := &net.Dialer{Timeout: time.Duration(timeoutSec) * time.Second} + transport := &http.Transport{ + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: time.Duration(timeoutSec) * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ResponseHeaderTimeout: time.Duration(timeoutSec) * time.Second, + } + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + if strings.HasSuffix(host, "github.io") { + if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil { + return conn, nil + } + return dialer.DialContext(ctx, "tcp6", addr) + } + return dialer.DialContext(ctx, network, addr) + } + return &http.Client{Transport: transport} +} + +var httpClient = newHTTPClient() + +func fetchJSON[T any](ctx context.Context, url string, out *upstreamEnvelope[T]) error { + var lastErr error + attempts := common.GetEnvOrDefault("SYNC_HTTP_RETRY", 3) + if attempts < 1 { + attempts = 1 + } + baseDelay := 200 * time.Millisecond + maxMB := common.GetEnvOrDefault("SYNC_HTTP_MAX_MB", 10) + maxBytes := int64(maxMB) << 20 + for attempt := 0; attempt < attempts; attempt++ { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + // ETag conditional request + cacheMutex.RLock() + if et := etagCache[url]; et != "" { + req.Header.Set("If-None-Match", et) + } + cacheMutex.RUnlock() + + resp, err := httpClient.Do(req) + if err != nil { + lastErr = err + // backoff with jitter + sleep := baseDelay * time.Duration(1< id + vendorIDCache := make(map[string]int) + + for _, name := range missing { + up, ok := modelByName[name] + if !ok { + skipped = append(skipped, name) + continue + } + + // 若本地已存在且设置为不同步,则跳过(极端情况:缺失列表与本地状态不同步时) + var existing model.Model + if err := model.DB.Where("model_name = ?", name).First(&existing).Error; err == nil { + if existing.SyncOfficial == 0 { + skipped = append(skipped, name) + continue + } + } + + // 确保 vendor 存在 + vendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors) + + // 创建模型 + mi := &model.Model{ + ModelName: name, + Description: up.Description, + Icon: up.Icon, + Tags: up.Tags, + VendorID: vendorID, + Status: chooseStatus(up.Status, 1), + NameRule: up.NameRule, + } + if err := mi.Insert(); err == nil { + createdModels++ + createdList = append(createdList, name) + } else { + skipped = append(skipped, name) + } + } + + // 4) 处理可选覆盖(更新本地已有模型的差异字段) + if len(req.Overwrite) > 0 { + // vendorIDCache 已用于创建阶段,可复用 + for _, ow := range req.Overwrite { + up, ok := modelByName[ow.ModelName] + if !ok { + continue + } + var local model.Model + if err := model.DB.Where("model_name = ?", ow.ModelName).First(&local).Error; err != nil { + continue + } + + // 跳过被禁用官方同步的模型 + if local.SyncOfficial == 0 { + continue + } + + // 映射 vendor + newVendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors) + + // 应用字段覆盖(事务) + _ = model.DB.Transaction(func(tx *gorm.DB) error { + needUpdate := false + if containsField(ow.Fields, "description") { + local.Description = up.Description + needUpdate = true + } + if containsField(ow.Fields, "icon") { + local.Icon = up.Icon + needUpdate = true + } + if containsField(ow.Fields, "tags") { + local.Tags = up.Tags + needUpdate = true + } + if containsField(ow.Fields, "vendor") { + local.VendorID = newVendorID + needUpdate = true + } + if containsField(ow.Fields, "name_rule") { + local.NameRule = up.NameRule + needUpdate = true + } + if containsField(ow.Fields, "status") { + local.Status = chooseStatus(up.Status, local.Status) + needUpdate = true + } + if !needUpdate { + return nil + } + if err := tx.Save(&local).Error; err != nil { + return err + } + updatedModels++ + updatedList = append(updatedList, ow.ModelName) + return nil + }) + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "created_models": createdModels, + "created_vendors": createdVendors, + "updated_models": updatedModels, + "skipped_models": skipped, + "created_list": createdList, + "updated_list": updatedList, + "source": gin.H{ + "locale": req.Locale, + "models_url": modelsURL, + "vendors_url": vendorsURL, + }, + }, + }) +} + +func containsField(fields []string, key string) bool { + key = strings.ToLower(strings.TrimSpace(key)) + for _, f := range fields { + if strings.ToLower(strings.TrimSpace(f)) == key { + return true + } + } + return false +} + +func coalesce(a, b string) string { + if strings.TrimSpace(a) != "" { + return a + } + return b +} + +func chooseStatus(primary, fallback int) int { + if primary == 0 && fallback != 0 { + return fallback + } + if primary != 0 { + return primary + } + return 1 +} + +// SyncUpstreamPreview 预览上游与本地的差异(仅用于弹窗选择) +func SyncUpstreamPreview(c *gin.Context) { + // 1) 拉取上游数据 + timeoutSec := common.GetEnvOrDefault("SYNC_HTTP_TIMEOUT_SECONDS", 15) + ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(timeoutSec)*time.Second) + defer cancel() + + locale := c.Query("locale") + modelsURL, vendorsURL := getUpstreamURLs(locale) + + var vendorsEnv upstreamEnvelope[upstreamVendor] + var modelsEnv upstreamEnvelope[upstreamModel] + var fetchErr error + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + _ = fetchJSON(ctx, vendorsURL, &vendorsEnv) + }() + go func() { + defer wg.Done() + if err := fetchJSON(ctx, modelsURL, &modelsEnv); err != nil { + fetchErr = err + } + }() + wg.Wait() + if fetchErr != nil { + c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + fetchErr.Error(), "locale": locale, "source_urls": gin.H{"models_url": modelsURL, "vendors_url": vendorsURL}}) + return + } + + vendorByName := make(map[string]upstreamVendor) + for _, v := range vendorsEnv.Data { + if v.Name != "" { + vendorByName[v.Name] = v + } + } + modelByName := make(map[string]upstreamModel) + upstreamNames := make([]string, 0, len(modelsEnv.Data)) + for _, m := range modelsEnv.Data { + if m.ModelName != "" { + modelByName[m.ModelName] = m + upstreamNames = append(upstreamNames, m.ModelName) + } + } + + // 2) 本地已有模型 + var locals []model.Model + if len(upstreamNames) > 0 { + _ = model.DB.Where("model_name IN ? AND sync_official <> 0", upstreamNames).Find(&locals).Error + } + + // 本地 vendor 名称映射 + vendorIdSet := make(map[int]struct{}) + for _, m := range locals { + if m.VendorID != 0 { + vendorIdSet[m.VendorID] = struct{}{} + } + } + vendorIDs := make([]int, 0, len(vendorIdSet)) + for id := range vendorIdSet { + vendorIDs = append(vendorIDs, id) + } + idToVendorName := make(map[int]string) + if len(vendorIDs) > 0 { + var dbVendors []model.Vendor + _ = model.DB.Where("id IN ?", vendorIDs).Find(&dbVendors).Error + for _, v := range dbVendors { + idToVendorName[v.Id] = v.Name + } + } + + // 3) 缺失且上游存在的模型 + missingList, _ := model.GetMissingModels() + var missing []string + for _, name := range missingList { + if _, ok := modelByName[name]; ok { + missing = append(missing, name) + } + } + + // 4) 计算冲突字段 + type conflictField struct { + Field string `json:"field"` + Local interface{} `json:"local"` + Upstream interface{} `json:"upstream"` + } + type conflictItem struct { + ModelName string `json:"model_name"` + Fields []conflictField `json:"fields"` + } + + var conflicts []conflictItem + for _, local := range locals { + up, ok := modelByName[local.ModelName] + if !ok { + continue + } + fields := make([]conflictField, 0, 6) + if strings.TrimSpace(local.Description) != strings.TrimSpace(up.Description) { + fields = append(fields, conflictField{Field: "description", Local: local.Description, Upstream: up.Description}) + } + if strings.TrimSpace(local.Icon) != strings.TrimSpace(up.Icon) { + fields = append(fields, conflictField{Field: "icon", Local: local.Icon, Upstream: up.Icon}) + } + if strings.TrimSpace(local.Tags) != strings.TrimSpace(up.Tags) { + fields = append(fields, conflictField{Field: "tags", Local: local.Tags, Upstream: up.Tags}) + } + // vendor 对比使用名称 + localVendor := idToVendorName[local.VendorID] + if strings.TrimSpace(localVendor) != strings.TrimSpace(up.VendorName) { + fields = append(fields, conflictField{Field: "vendor", Local: localVendor, Upstream: up.VendorName}) + } + if local.NameRule != up.NameRule { + fields = append(fields, conflictField{Field: "name_rule", Local: local.NameRule, Upstream: up.NameRule}) + } + if local.Status != chooseStatus(up.Status, local.Status) { + fields = append(fields, conflictField{Field: "status", Local: local.Status, Upstream: up.Status}) + } + if len(fields) > 0 { + conflicts = append(conflicts, conflictItem{ModelName: local.ModelName, Fields: fields}) + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "missing": missing, + "conflicts": conflicts, + "source": gin.H{ + "locale": locale, + "models_url": modelsURL, + "vendors_url": vendorsURL, + }, + }, + }) +} diff --git a/controller/option.go b/controller/option.go index decdb0d40..e5f2b75b0 100644 --- a/controller/option.go +++ b/controller/option.go @@ -2,6 +2,7 @@ package controller import ( "encoding/json" + "fmt" "net/http" "one-api/common" "one-api/model" @@ -35,8 +36,13 @@ func GetOptions(c *gin.Context) { return } +type OptionUpdateRequest struct { + Key string `json:"key"` + Value any `json:"value"` +} + func UpdateOption(c *gin.Context) { - var option model.Option + var option OptionUpdateRequest err := json.NewDecoder(c.Request.Body).Decode(&option) if err != nil { c.JSON(http.StatusBadRequest, gin.H{ @@ -45,6 +51,16 @@ func UpdateOption(c *gin.Context) { }) return } + switch option.Value.(type) { + case bool: + option.Value = common.Interface2String(option.Value.(bool)) + case float64: + option.Value = common.Interface2String(option.Value.(float64)) + case int: + option.Value = common.Interface2String(option.Value.(int)) + default: + option.Value = fmt.Sprintf("%v", option.Value) + } switch option.Key { case "GitHubOAuthEnabled": if option.Value == "true" && common.GitHubClientId == "" { @@ -104,7 +120,7 @@ func UpdateOption(c *gin.Context) { return } case "GroupRatio": - err = ratio_setting.CheckGroupRatio(option.Value) + err = ratio_setting.CheckGroupRatio(option.Value.(string)) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -113,7 +129,7 @@ func UpdateOption(c *gin.Context) { return } case "ModelRequestRateLimitGroup": - err = setting.CheckModelRequestRateLimitGroup(option.Value) + err = setting.CheckModelRequestRateLimitGroup(option.Value.(string)) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -122,7 +138,7 @@ func UpdateOption(c *gin.Context) { return } case "console_setting.api_info": - err = console_setting.ValidateConsoleSettings(option.Value, "ApiInfo") + err = console_setting.ValidateConsoleSettings(option.Value.(string), "ApiInfo") if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -131,7 +147,7 @@ func UpdateOption(c *gin.Context) { return } case "console_setting.announcements": - err = console_setting.ValidateConsoleSettings(option.Value, "Announcements") + err = console_setting.ValidateConsoleSettings(option.Value.(string), "Announcements") if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -140,7 +156,7 @@ func UpdateOption(c *gin.Context) { return } case "console_setting.faq": - err = console_setting.ValidateConsoleSettings(option.Value, "FAQ") + err = console_setting.ValidateConsoleSettings(option.Value.(string), "FAQ") if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -149,7 +165,7 @@ func UpdateOption(c *gin.Context) { return } case "console_setting.uptime_kuma_groups": - err = console_setting.ValidateConsoleSettings(option.Value, "UptimeKumaGroups") + err = console_setting.ValidateConsoleSettings(option.Value.(string), "UptimeKumaGroups") if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -158,7 +174,7 @@ func UpdateOption(c *gin.Context) { return } } - err = model.UpdateOption(option.Key, option.Value) + err = model.UpdateOption(option.Key, option.Value.(string)) if err != nil { common.ApiError(c, err) return diff --git a/controller/ratio_config.go b/controller/ratio_config.go index 6ddc3d9ef..0cb4aa73b 100644 --- a/controller/ratio_config.go +++ b/controller/ratio_config.go @@ -1,24 +1,24 @@ package controller import ( - "net/http" - "one-api/setting/ratio_setting" + "net/http" + "one-api/setting/ratio_setting" - "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin" ) func GetRatioConfig(c *gin.Context) { - if !ratio_setting.IsExposeRatioEnabled() { - c.JSON(http.StatusForbidden, gin.H{ - "success": false, - "message": "倍率配置接口未启用", - }) - return - } + if !ratio_setting.IsExposeRatioEnabled() { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "倍率配置接口未启用", + }) + return + } - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": ratio_setting.GetExposedData(), - }) -} \ No newline at end of file + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": ratio_setting.GetExposedData(), + }) +} diff --git a/controller/ratio_sync.go b/controller/ratio_sync.go index 6fba0aac3..7a481c476 100644 --- a/controller/ratio_sync.go +++ b/controller/ratio_sync.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "io" + "net" "net/http" "one-api/logger" "strings" @@ -21,8 +23,26 @@ const ( defaultTimeoutSeconds = 10 defaultEndpoint = "/api/ratio_config" maxConcurrentFetches = 8 + maxRatioConfigBytes = 10 << 20 // 10MB + floatEpsilon = 1e-9 ) +func nearlyEqual(a, b float64) bool { + if a > b { + return a-b < floatEpsilon + } + return b-a < floatEpsilon +} + +func valuesEqual(a, b interface{}) bool { + af, aok := a.(float64) + bf, bok := b.(float64) + if aok && bok { + return nearlyEqual(af, bf) + } + return a == b +} + var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} type upstreamResult struct { @@ -87,7 +107,23 @@ func FetchUpstreamRatios(c *gin.Context) { sem := make(chan struct{}, maxConcurrentFetches) - client := &http.Client{Transport: &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second}} + dialer := &net.Dialer{Timeout: 10 * time.Second} + transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second} + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + // 对 github.io 优先尝试 IPv4,失败则回退 IPv6 + if strings.HasSuffix(host, "github.io") { + if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil { + return conn, nil + } + return dialer.DialContext(ctx, "tcp6", addr) + } + return dialer.DialContext(ctx, network, addr) + } + client := &http.Client{Transport: transport} for _, chn := range upstreams { wg.Add(1) @@ -98,12 +134,17 @@ func FetchUpstreamRatios(c *gin.Context) { defer func() { <-sem }() endpoint := chItem.Endpoint - if endpoint == "" { - endpoint = defaultEndpoint - } else if !strings.HasPrefix(endpoint, "/") { - endpoint = "/" + endpoint + var fullURL string + if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") { + fullURL = endpoint + } else { + if endpoint == "" { + endpoint = defaultEndpoint + } else if !strings.HasPrefix(endpoint, "/") { + endpoint = "/" + endpoint + } + fullURL = chItem.BaseURL + endpoint } - fullURL := chItem.BaseURL + endpoint uniqueName := chItem.Name if chItem.ID != 0 { @@ -120,10 +161,19 @@ func FetchUpstreamRatios(c *gin.Context) { return } - resp, err := client.Do(httpReq) - if err != nil { - logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+err.Error()) - ch <- upstreamResult{Name: uniqueName, Err: err.Error()} + // 简单重试:最多 3 次,指数退避 + var resp *http.Response + var lastErr error + for attempt := 0; attempt < 3; attempt++ { + resp, lastErr = client.Do(httpReq) + if lastErr == nil { + break + } + time.Sleep(time.Duration(200*(1< data 为 map[string]any,包含 model_ratio/completion_ratio/cache_ratio/model_price // type2: /api/pricing -> data 为 []Pricing 列表,需要转换为与 type1 相同的 map 格式 @@ -141,7 +197,7 @@ func FetchUpstreamRatios(c *gin.Context) { Message string `json:"message"` } - if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + if err := json.NewDecoder(limited).Decode(&body); err != nil { logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) ch <- upstreamResult{Name: uniqueName, Err: err.Error()} return @@ -152,6 +208,8 @@ func FetchUpstreamRatios(c *gin.Context) { return } + // 若 Data 为空,将继续按 type1 尝试解析(与多数静态 ratio_config 兼容) + // 尝试按 type1 解析 var type1Data map[string]any if err := json.Unmarshal(body.Data, &type1Data); err == nil { @@ -357,9 +415,9 @@ func buildDifferences(localData map[string]any, successfulChannels []struct { upstreamValue = val hasUpstreamValue = true - if localValue != nil && localValue != val { + if localValue != nil && !valuesEqual(localValue, val) { hasDifference = true - } else if localValue == val { + } else if valuesEqual(localValue, val) { upstreamValue = "same" } } @@ -466,6 +524,13 @@ func GetSyncableChannels(c *gin.Context) { } } + syncableChannels = append(syncableChannels, dto.SyncableChannel{ + ID: -100, + Name: "官方倍率预设", + BaseURL: "https://basellm.github.io", + Status: 1, + }) + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/controller/relay.go b/controller/relay.go index c055ef71e..23d725153 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -3,7 +3,6 @@ package controller import ( "bytes" "fmt" - "github.com/bytedance/gopkg/util/gopool" "io" "log" "net/http" @@ -22,6 +21,8 @@ import ( "one-api/types" "strings" + "github.com/bytedance/gopkg/util/gopool" + "github.com/gin-gonic/gin" "github.com/gorilla/websocket" ) @@ -138,15 +139,15 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) { // common.SetContextKey(c, constant.ContextKeyTokenCountMeta, meta) - preConsumedQuota, newAPIError := service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) + newAPIError = service.PreConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo) if newAPIError != nil { return } defer func() { // Only return quota if downstream failed and quota was actually pre-consumed - if newAPIError != nil && preConsumedQuota != 0 { - service.ReturnPreConsumedQuota(c, relayInfo, preConsumedQuota) + if newAPIError != nil && relayInfo.FinalPreConsumedQuota != 0 { + service.ReturnPreConsumedQuota(c, relayInfo) } }() @@ -276,14 +277,13 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) { logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error())) - - gopool.Go(func() { - // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 - // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously - if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan { + // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 + // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously + if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan { + gopool.Go(func() { service.DisableChannel(channelError, err.Error()) - } - }) + }) + } if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) { // 保存错误日志到mysql中 @@ -383,11 +383,14 @@ func RelayNotFound(c *gin.Context) { func RelayTask(c *gin.Context) { retryTimes := common.RetryTimes channelId := c.GetInt("channel_id") - relayMode := c.GetInt("relay_mode") group := c.GetString("group") originalModel := c.GetString("original_model") c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) - taskErr := taskRelayHandler(c, relayMode) + relayInfo, err := relaycommon.GenRelayInfo(c, types.RelayFormatTask, nil, nil) + if err != nil { + return + } + taskErr := taskRelayHandler(c, relayInfo) if taskErr == nil { retryTimes = 0 } @@ -407,7 +410,7 @@ func RelayTask(c *gin.Context) { requestBody, _ := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - taskErr = taskRelayHandler(c, relayMode) + taskErr = taskRelayHandler(c, relayInfo) } useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { @@ -422,13 +425,13 @@ func RelayTask(c *gin.Context) { } } -func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError { +func taskRelayHandler(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.TaskError { var err *dto.TaskError - switch relayMode { + switch relayInfo.RelayMode { case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID, relayconstant.RelayModeVideoFetchByID: - err = relay.RelayTaskFetch(c, relayMode) + err = relay.RelayTaskFetch(c, relayInfo.RelayMode) default: - err = relay.RelayTaskSubmit(c, relayMode) + err = relay.RelayTaskSubmit(c, relayInfo) } return err } diff --git a/controller/uptime_kuma.go b/controller/uptime_kuma.go index 05d6297eb..41b9695c3 100644 --- a/controller/uptime_kuma.go +++ b/controller/uptime_kuma.go @@ -31,7 +31,7 @@ type Monitor struct { type UptimeGroupResult struct { CategoryName string `json:"categoryName"` - Monitors []Monitor `json:"monitors"` + Monitors []Monitor `json:"monitors"` } func getAndDecode(ctx context.Context, client *http.Client, url string, dest interface{}) error { @@ -57,29 +57,29 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st url, _ := groupConfig["url"].(string) slug, _ := groupConfig["slug"].(string) categoryName, _ := groupConfig["categoryName"].(string) - + result := UptimeGroupResult{ CategoryName: categoryName, - Monitors: []Monitor{}, + Monitors: []Monitor{}, } - + if url == "" || slug == "" { return result } baseURL := strings.TrimSuffix(url, "/") - + var statusData struct { PublicGroupList []struct { - ID int `json:"id"` - Name string `json:"name"` + ID int `json:"id"` + Name string `json:"name"` MonitorList []struct { ID int `json:"id"` Name string `json:"name"` } `json:"monitorList"` } `json:"publicGroupList"` } - + var heartbeatData struct { HeartbeatList map[string][]struct { Status int `json:"status"` @@ -88,11 +88,11 @@ func fetchGroupData(ctx context.Context, client *http.Client, groupConfig map[st } g, gCtx := errgroup.WithContext(ctx) - g.Go(func() error { - return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData) + g.Go(func() error { + return getAndDecode(gCtx, client, baseURL+apiStatusPath+slug, &statusData) }) - g.Go(func() error { - return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData) + g.Go(func() error { + return getAndDecode(gCtx, client, baseURL+apiHeartbeatPath+slug, &heartbeatData) }) if g.Wait() != nil { @@ -139,7 +139,7 @@ func GetUptimeKumaStatus(c *gin.Context) { client := &http.Client{Timeout: httpTimeout} results := make([]UptimeGroupResult, len(groups)) - + g, gCtx := errgroup.WithContext(ctx) for i, group := range groups { i, group := i, group @@ -148,7 +148,7 @@ func GetUptimeKumaStatus(c *gin.Context) { return nil }) } - + g.Wait() c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": results}) -} \ No newline at end of file +} diff --git a/controller/user.go b/controller/user.go index c9795c0cb..982329cec 100644 --- a/controller/user.go +++ b/controller/user.go @@ -210,6 +210,7 @@ func Register(c *gin.Context) { Password: user.Password, DisplayName: user.Username, InviterId: inviterId, + Role: common.RoleCommonUser, // 明确设置角色为普通用户 } if common.EmailVerificationEnabled { cleanUser.Email = user.Email @@ -426,6 +427,7 @@ func GetAffCode(c *gin.Context) { func GetSelf(c *gin.Context) { id := c.GetInt("id") + userRole := c.GetInt("role") user, err := model.GetUserById(id, false) if err != nil { common.ApiError(c, err) @@ -434,14 +436,134 @@ func GetSelf(c *gin.Context) { // Hide admin remarks: set to empty to trigger omitempty tag, ensuring the remark field is not included in JSON returned to regular users user.Remark = "" + // 计算用户权限信息 + permissions := calculateUserPermissions(userRole) + + // 获取用户设置并提取sidebar_modules + userSetting := user.GetSetting() + + // 构建响应数据,包含用户信息和权限 + responseData := map[string]interface{}{ + "id": user.Id, + "username": user.Username, + "display_name": user.DisplayName, + "role": user.Role, + "status": user.Status, + "email": user.Email, + "group": user.Group, + "quota": user.Quota, + "used_quota": user.UsedQuota, + "request_count": user.RequestCount, + "aff_code": user.AffCode, + "aff_count": user.AffCount, + "aff_quota": user.AffQuota, + "aff_history_quota": user.AffHistoryQuota, + "inviter_id": user.InviterId, + "linux_do_id": user.LinuxDOId, + "setting": user.Setting, + "stripe_customer": user.StripeCustomer, + "sidebar_modules": userSetting.SidebarModules, // 正确提取sidebar_modules字段 + "permissions": permissions, // 新增权限字段 + } + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": user, + "data": responseData, }) return } +// 计算用户权限的辅助函数 +func calculateUserPermissions(userRole int) map[string]interface{} { + permissions := map[string]interface{}{} + + // 根据用户角色计算权限 + if userRole == common.RoleRootUser { + // 超级管理员不需要边栏设置功能 + permissions["sidebar_settings"] = false + permissions["sidebar_modules"] = map[string]interface{}{} + } else if userRole == common.RoleAdminUser { + // 管理员可以设置边栏,但不包含系统设置功能 + permissions["sidebar_settings"] = true + permissions["sidebar_modules"] = map[string]interface{}{ + "admin": map[string]interface{}{ + "setting": false, // 管理员不能访问系统设置 + }, + } + } else { + // 普通用户只能设置个人功能,不包含管理员区域 + permissions["sidebar_settings"] = true + permissions["sidebar_modules"] = map[string]interface{}{ + "admin": false, // 普通用户不能访问管理员区域 + } + } + + return permissions +} + +// 根据用户角色生成默认的边栏配置 +func generateDefaultSidebarConfig(userRole int) string { + defaultConfig := map[string]interface{}{} + + // 聊天区域 - 所有用户都可以访问 + defaultConfig["chat"] = map[string]interface{}{ + "enabled": true, + "playground": true, + "chat": true, + } + + // 控制台区域 - 所有用户都可以访问 + defaultConfig["console"] = map[string]interface{}{ + "enabled": true, + "detail": true, + "token": true, + "log": true, + "midjourney": true, + "task": true, + } + + // 个人中心区域 - 所有用户都可以访问 + defaultConfig["personal"] = map[string]interface{}{ + "enabled": true, + "topup": true, + "personal": true, + } + + // 管理员区域 - 根据角色决定 + if userRole == common.RoleAdminUser { + // 管理员可以访问管理员区域,但不能访问系统设置 + defaultConfig["admin"] = map[string]interface{}{ + "enabled": true, + "channel": true, + "models": true, + "redemption": true, + "user": true, + "setting": false, // 管理员不能访问系统设置 + } + } else if userRole == common.RoleRootUser { + // 超级管理员可以访问所有功能 + defaultConfig["admin"] = map[string]interface{}{ + "enabled": true, + "channel": true, + "models": true, + "redemption": true, + "user": true, + "setting": true, + } + } + // 普通用户不包含admin区域 + + // 转换为JSON字符串 + configBytes, err := json.Marshal(defaultConfig) + if err != nil { + common.SysLog("生成默认边栏配置失败: " + err.Error()) + return "" + } + + return string(configBytes) +} + func GetUserModels(c *gin.Context) { id, err := strconv.Atoi(c.Param("id")) if err != nil { @@ -528,8 +650,8 @@ func UpdateUser(c *gin.Context) { } func UpdateSelf(c *gin.Context) { - var user model.User - err := json.NewDecoder(c.Request.Body).Decode(&user) + var requestData map[string]interface{} + err := json.NewDecoder(c.Request.Body).Decode(&requestData) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -537,6 +659,60 @@ func UpdateSelf(c *gin.Context) { }) return } + + // 检查是否是sidebar_modules更新请求 + if sidebarModules, exists := requestData["sidebar_modules"]; exists { + userId := c.GetInt("id") + user, err := model.GetUserById(userId, false) + if err != nil { + common.ApiError(c, err) + return + } + + // 获取当前用户设置 + currentSetting := user.GetSetting() + + // 更新sidebar_modules字段 + if sidebarModulesStr, ok := sidebarModules.(string); ok { + currentSetting.SidebarModules = sidebarModulesStr + } + + // 保存更新后的设置 + user.SetSetting(currentSetting) + if err := user.Update(false); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "更新设置失败: " + err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "设置更新成功", + }) + return + } + + // 原有的用户信息更新逻辑 + var user model.User + requestDataBytes, err := json.Marshal(requestData) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + err = json.Unmarshal(requestDataBytes, &user) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的参数", + }) + return + } + if user.Password == "" { user.Password = "$I_LOVE_U" // make Validator happy :) } @@ -679,6 +855,7 @@ func CreateUser(c *gin.Context) { Username: user.Username, Password: user.Password, DisplayName: user.DisplayName, + Role: user.Role, // 保持管理员设置的角色 } if err := cleanUser.Insert(0); err != nil { common.ApiError(c, err) @@ -920,6 +1097,7 @@ type UpdateUserSettingRequest struct { WebhookUrl string `json:"webhook_url,omitempty"` WebhookSecret string `json:"webhook_secret,omitempty"` NotificationEmail string `json:"notification_email,omitempty"` + BarkUrl string `json:"bark_url,omitempty"` AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"` RecordIpLog bool `json:"record_ip_log"` } @@ -935,7 +1113,7 @@ func UpdateUserSetting(c *gin.Context) { } // 验证预警类型 - if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook { + if req.QuotaWarningType != dto.NotifyTypeEmail && req.QuotaWarningType != dto.NotifyTypeWebhook && req.QuotaWarningType != dto.NotifyTypeBark { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无效的预警类型", @@ -983,6 +1161,33 @@ func UpdateUserSetting(c *gin.Context) { } } + // 如果是Bark类型,验证Bark URL + if req.QuotaWarningType == dto.NotifyTypeBark { + if req.BarkUrl == "" { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "Bark推送URL不能为空", + }) + return + } + // 验证URL格式 + if _, err := url.ParseRequestURI(req.BarkUrl); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "无效的Bark推送URL", + }) + return + } + // 检查是否是HTTP或HTTPS + if !strings.HasPrefix(req.BarkUrl, "https://") && !strings.HasPrefix(req.BarkUrl, "http://") { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "Bark推送URL必须以http://或https://开头", + }) + return + } + } + userId := c.GetInt("id") user, err := model.GetUserById(userId, true) if err != nil { @@ -1011,6 +1216,11 @@ func UpdateUserSetting(c *gin.Context) { settings.NotificationEmail = req.NotificationEmail } + // 如果是Bark类型,添加Bark URL到设置中 + if req.QuotaWarningType == dto.NotifyTypeBark { + settings.BarkUrl = req.BarkUrl + } + // 更新用户设置 user.SetSetting(settings) if err := user.Update(false); err != nil { diff --git a/dto/claude.go b/dto/claude.go index 5c4396f23..963e588bf 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -488,14 +488,14 @@ func (c *ClaudeResponse) GetClaudeError() *types.ClaudeError { case string: // 处理简单字符串错误 return &types.ClaudeError{ - Type: "error", + Type: "upstream_error", Message: err, } default: // 未知类型,尝试转换为字符串 return &types.ClaudeError{ - Type: "unknown_error", - Message: fmt.Sprintf("%v", err), + Type: "unknown_upstream_error", + Message: fmt.Sprintf("unknown_error: %v", err), } } } diff --git a/dto/gemini.go b/dto/gemini.go index 5df67ba0b..cd5d74cdd 100644 --- a/dto/gemini.go +++ b/dto/gemini.go @@ -2,11 +2,12 @@ package dto import ( "encoding/json" - "github.com/gin-gonic/gin" "one-api/common" "one-api/logger" "one-api/types" "strings" + + "github.com/gin-gonic/gin" ) type GeminiChatRequest struct { @@ -268,14 +269,15 @@ type GeminiChatResponse struct { } type GeminiUsageMetadata struct { - PromptTokenCount int `json:"promptTokenCount"` - CandidatesTokenCount int `json:"candidatesTokenCount"` - TotalTokenCount int `json:"totalTokenCount"` - ThoughtsTokenCount int `json:"thoughtsTokenCount"` - PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"` + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + ThoughtsTokenCount int `json:"thoughtsTokenCount"` + PromptTokensDetails []GeminiModalityTokenCount `json:"promptTokensDetails"` + CandidatesTokensDetails []GeminiModalityTokenCount `json:"candidatesTokensDetails"` } -type GeminiPromptTokensDetails struct { +type GeminiModalityTokenCount struct { Modality string `json:"modality"` TokenCount int `json:"tokenCount"` } diff --git a/dto/openai_image.go b/dto/openai_image.go index 9e838688e..5aece25f2 100644 --- a/dto/openai_image.go +++ b/dto/openai_image.go @@ -59,6 +59,31 @@ func (i *ImageRequest) UnmarshalJSON(data []byte) error { return nil } +// 序列化时需要重新把字段平铺 +func (r ImageRequest) MarshalJSON() ([]byte, error) { + // 将已定义字段转为 map + type Alias ImageRequest + alias := Alias(r) + base, err := common.Marshal(alias) + if err != nil { + return nil, err + } + + var baseMap map[string]json.RawMessage + if err := common.Unmarshal(base, &baseMap); err != nil { + return nil, err + } + + // 合并 ExtraFields + for k, v := range r.Extra { + if _, exists := baseMap[k]; !exists { + baseMap[k] = v + } + } + + return json.Marshal(baseMap) +} + func GetJSONFieldNames(t reflect.Type) map[string]struct{} { fields := make(map[string]struct{}) for i := 0; i < t.NumField(); i++ { diff --git a/dto/openai_request.go b/dto/openai_request.go index 881ec2241..cd05a63c9 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -57,18 +57,24 @@ type GeneralOpenAIRequest struct { Dimensions int `json:"dimensions,omitempty"` Modalities json.RawMessage `json:"modalities,omitempty"` Audio json.RawMessage `json:"audio,omitempty"` - EnableThinking any `json:"enable_thinking,omitempty"` // ali - THINKING json.RawMessage `json:"thinking,omitempty"` // doubao,zhipu_v4 - ExtraBody json.RawMessage `json:"extra_body,omitempty"` - SearchParameters any `json:"search_parameters,omitempty"` //xai - WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"` + // gemini + ExtraBody json.RawMessage `json:"extra_body,omitempty"` + //xai + SearchParameters json.RawMessage `json:"search_parameters,omitempty"` + // claude + WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"` // OpenRouter Params Usage json.RawMessage `json:"usage,omitempty"` Reasoning json.RawMessage `json:"reasoning,omitempty"` // Ali Qwen Params VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"` + EnableThinking any `json:"enable_thinking,omitempty"` // ollama Params Think json.RawMessage `json:"think,omitempty"` + // baidu v2 + WebSearch json.RawMessage `json:"web_search,omitempty"` + // doubao,zhipu_v4 + THINKING json.RawMessage `json:"thinking,omitempty"` } func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta { @@ -760,27 +766,27 @@ type WebSearchOptions struct { // https://platform.openai.com/docs/api-reference/responses/create type OpenAIResponsesRequest struct { - Model string `json:"model"` - Input any `json:"input,omitempty"` - Include json.RawMessage `json:"include,omitempty"` - Instructions json.RawMessage `json:"instructions,omitempty"` - MaxOutputTokens uint `json:"max_output_tokens,omitempty"` - Metadata json.RawMessage `json:"metadata,omitempty"` - ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` - PreviousResponseID string `json:"previous_response_id,omitempty"` - Reasoning *Reasoning `json:"reasoning,omitempty"` - ServiceTier string `json:"service_tier,omitempty"` - Store bool `json:"store,omitempty"` - Stream bool `json:"stream,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - Text json.RawMessage `json:"text,omitempty"` - ToolChoice json.RawMessage `json:"tool_choice,omitempty"` - Tools []map[string]any `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map - TopP float64 `json:"top_p,omitempty"` - Truncation string `json:"truncation,omitempty"` - User string `json:"user,omitempty"` - MaxToolCalls uint `json:"max_tool_calls,omitempty"` - Prompt json.RawMessage `json:"prompt,omitempty"` + Model string `json:"model"` + Input json.RawMessage `json:"input,omitempty"` + Include json.RawMessage `json:"include,omitempty"` + Instructions json.RawMessage `json:"instructions,omitempty"` + MaxOutputTokens uint `json:"max_output_tokens,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` + ParallelToolCalls bool `json:"parallel_tool_calls,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` + Reasoning *Reasoning `json:"reasoning,omitempty"` + ServiceTier string `json:"service_tier,omitempty"` + Store bool `json:"store,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Text json.RawMessage `json:"text,omitempty"` + ToolChoice json.RawMessage `json:"tool_choice,omitempty"` + Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少,MCP 参数太多不确定,所以用 map + TopP float64 `json:"top_p,omitempty"` + Truncation string `json:"truncation,omitempty"` + User string `json:"user,omitempty"` + MaxToolCalls uint `json:"max_tool_calls,omitempty"` + Prompt json.RawMessage `json:"prompt,omitempty"` } func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta { @@ -832,8 +838,7 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta { } if len(r.Tools) > 0 { - toolStr, _ := common.Marshal(r.Tools) - texts = append(texts, string(toolStr)) + texts = append(texts, string(r.Tools)) } return &types.TokenCountMeta{ @@ -853,6 +858,14 @@ func (r *OpenAIResponsesRequest) SetModelName(modelName string) { } } +func (r *OpenAIResponsesRequest) GetToolsMap() []map[string]any { + var toolsMap []map[string]any + if len(r.Tools) > 0 { + _ = common.Unmarshal(r.Tools, &toolsMap) + } + return toolsMap +} + type Reasoning struct { Effort string `json:"effort,omitempty"` Summary string `json:"summary,omitempty"` @@ -879,13 +892,21 @@ func (r *OpenAIResponsesRequest) ParseInput() []MediaInput { var inputs []MediaInput // Try string first - if str, ok := r.Input.(string); ok { + // if str, ok := common.GetJsonType(r.Input); ok { + // inputs = append(inputs, MediaInput{Type: "input_text", Text: str}) + // return inputs + // } + if common.GetJsonType(r.Input) == "string" { + var str string + _ = common.Unmarshal(r.Input, &str) inputs = append(inputs, MediaInput{Type: "input_text", Text: str}) return inputs } // Try array of parts - if array, ok := r.Input.([]any); ok { + if common.GetJsonType(r.Input) == "array" { + var array []any + _ = common.Unmarshal(r.Input, &array) for _, itemAny := range array { // Already parsed MediaInput if media, ok := itemAny.(MediaInput); ok { diff --git a/dto/ratio_sync.go b/dto/ratio_sync.go index 6315f31ae..d6bbf68e1 100644 --- a/dto/ratio_sync.go +++ b/dto/ratio_sync.go @@ -1,23 +1,23 @@ package dto type UpstreamDTO struct { - ID int `json:"id,omitempty"` - Name string `json:"name" binding:"required"` - BaseURL string `json:"base_url" binding:"required"` - Endpoint string `json:"endpoint"` + ID int `json:"id,omitempty"` + Name string `json:"name" binding:"required"` + BaseURL string `json:"base_url" binding:"required"` + Endpoint string `json:"endpoint"` } type UpstreamRequest struct { - ChannelIDs []int64 `json:"channel_ids"` - Upstreams []UpstreamDTO `json:"upstreams"` - Timeout int `json:"timeout"` + ChannelIDs []int64 `json:"channel_ids"` + Upstreams []UpstreamDTO `json:"upstreams"` + Timeout int `json:"timeout"` } // TestResult 上游测试连通性结果 type TestResult struct { - Name string `json:"name"` - Status string `json:"status"` - Error string `json:"error,omitempty"` + Name string `json:"name"` + Status string `json:"status"` + Error string `json:"error,omitempty"` } // DifferenceItem 差异项 @@ -25,14 +25,14 @@ type TestResult struct { // Upstreams 为各渠道的上游值,具体数值 / "same" / nil type DifferenceItem struct { - Current interface{} `json:"current"` - Upstreams map[string]interface{} `json:"upstreams"` - Confidence map[string]bool `json:"confidence"` + Current interface{} `json:"current"` + Upstreams map[string]interface{} `json:"upstreams"` + Confidence map[string]bool `json:"confidence"` } type SyncableChannel struct { - ID int `json:"id"` - Name string `json:"name"` - BaseURL string `json:"base_url"` - Status int `json:"status"` -} \ No newline at end of file + ID int `json:"id"` + Name string `json:"name"` + BaseURL string `json:"base_url"` + Status int `json:"status"` +} diff --git a/dto/user_settings.go b/dto/user_settings.go index 2e1a15418..89dd926ef 100644 --- a/dto/user_settings.go +++ b/dto/user_settings.go @@ -6,11 +6,14 @@ type UserSetting struct { WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址 WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥 NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址 + BarkUrl string `json:"bark_url,omitempty"` // BarkUrl Bark推送URL AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型 RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP + SidebarModules string `json:"sidebar_modules,omitempty"` // SidebarModules 左侧边栏模块配置 } var ( NotifyTypeEmail = "email" // Email 邮件 NotifyTypeWebhook = "webhook" // Webhook + NotifyTypeBark = "bark" // Bark 推送 ) diff --git a/go.mod b/go.mod index 1a92947e5..501d966d5 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.0 + github.com/jinzhu/copier v0.4.0 github.com/joho/godotenv v1.5.1 github.com/pkg/errors v0.9.1 github.com/pquerna/otp v1.5.0 @@ -44,11 +45,7 @@ require ( ) require ( - github.com/Masterminds/goutils v1.1.1 // indirect - github.com/Masterminds/semver/v3 v3.2.0 // indirect - github.com/Masterminds/sprig/v3 v3.2.3 // indirect github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect - github.com/antlabs/pcopy v0.1.5 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.2 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.2 // indirect @@ -73,8 +70,6 @@ require ( github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect - github.com/huandu/xstrings v1.3.3 // indirect - github.com/imdario/mergo v0.3.11 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.7.1 // indirect @@ -85,14 +80,11 @@ require ( github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mitchellh/copystructure v1.0.0 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect - github.com/mitchellh/reflectwalk v1.0.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - github.com/spf13/cast v1.3.1 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect diff --git a/go.sum b/go.sum index 7b8104b95..189d09de4 100644 --- a/go.sum +++ b/go.sum @@ -1,19 +1,11 @@ github.com/Calcium-Ion/go-epay v0.0.4 h1:C96M7WfRLadcIVscWzwLiYs8etI1wrDmtFMuK2zP22A= github.com/Calcium-Ion/go-epay v0.0.4/go.mod h1:cxo/ZOg8ClvE3VAnCmEzbuyAZINSq7kFEN9oHj5WQ2U= -github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= -github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= -github.com/Masterminds/semver/v3 v3.2.0 h1:3MEsd0SM6jqZojhjLWWeBY+Kcjy9i6MQAeY7YgDP83g= -github.com/Masterminds/semver/v3 v3.2.0/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= -github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA= -github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+KcxaMk1lfrRnwCd1UUuOjJM/lri5eM1qMs= github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8= -github.com/antlabs/pcopy v0.1.5 h1:5Fa1ExY9T6ar3ysAi4rzB5jiYg72Innm+/ESEIOSHvQ= -github.com/antlabs/pcopy v0.1.5/go.mod h1:2FvdkPD3cFiM1CjGuXFCDQZqhKVcLI7IzeSJ2xUIOOI= github.com/aws/aws-sdk-go-v2 v1.37.2 h1:xkW1iMYawzcmYFYEV0UCMxc8gSsjCGEhBXQkdQywVbo= github.com/aws/aws-sdk-go-v2 v1.37.2/go.mod h1:9Q0OoGQoboYIAJyslFyF1f5K1Ryddop8gqMhWx/n4Wg= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.0 h1:6GMWV6CNpA/6fbFHnoAjrv4+LGfyTqZz2LtCHnspgDg= @@ -110,7 +102,6 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= -github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= @@ -121,10 +112,6 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4= -github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= -github.com/imdario/mergo v0.3.11 h1:3tnifQM4i+fbajXKBHXWEH+KvNHqojZ778UH75j3bGA= -github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -133,6 +120,8 @@ github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= +github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= @@ -163,12 +152,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ= -github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= -github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/IfikLNY= -github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -201,19 +186,14 @@ github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA= github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI= github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= -github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= -github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= -github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -251,36 +231,25 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw= github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.12.0 h1:UsYJhbzPYGsT0HbEdmYcqtCv8UNGvnaL561NnIUvaKg= golang.org/x/arch v0.12.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/image v0.23.0 h1:HseQ7c2OpPKTPVzNjG5fwJsOTCiiwS4QdsYi5XU6H68= golang.org/x/image v0.23.0/go.mod h1:wJJBTdLfCCf3tiHa1fNxpZmUI4mmoZvwMCPP0ddoNKY= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -289,29 +258,18 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= @@ -326,7 +284,6 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkep gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index 91311b867..0caf53617 100644 --- a/main.go +++ b/main.go @@ -94,13 +94,9 @@ func main() { } go controller.AutomaticallyUpdateChannels(frequency) } - if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { - frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) - if err != nil { - common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) - } - go controller.AutomaticallyTestChannels(frequency) - } + + go controller.AutomaticallyTestChannels() + if common.IsMasterNode && constant.UpdateTask { gopool.Go(func() { controller.UpdateMidjourneyTaskBulk() diff --git a/middleware/disable-cache.go b/middleware/disable-cache.go new file mode 100644 index 000000000..3076e90a8 --- /dev/null +++ b/middleware/disable-cache.go @@ -0,0 +1,12 @@ +package middleware + +import "github.com/gin-gonic/gin" + +func DisableCache() gin.HandlerFunc { + return func(c *gin.Context) { + c.Header("Cache-Control", "no-store, no-cache, must-revalidate, private, max-age=0") + c.Header("Pragma", "no-cache") + c.Header("Expires", "0") + c.Next() + } +} diff --git a/middleware/stats.go b/middleware/stats.go index 1c97983f7..e49e56991 100644 --- a/middleware/stats.go +++ b/middleware/stats.go @@ -18,12 +18,12 @@ func StatsMiddleware() gin.HandlerFunc { return func(c *gin.Context) { // 增加活跃连接数 atomic.AddInt64(&globalStats.activeConnections, 1) - + // 确保在请求结束时减少连接数 defer func() { atomic.AddInt64(&globalStats.activeConnections, -1) }() - + c.Next() } } @@ -38,4 +38,4 @@ func GetStats() StatsInfo { return StatsInfo{ ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections), } -} \ No newline at end of file +} diff --git a/model/channel.go b/model/channel.go index 7aa7d0975..a61b3eccf 100644 --- a/model/channel.go +++ b/model/channel.go @@ -47,6 +47,7 @@ type Channel struct { Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置 ParamOverride *string `json:"param_override" gorm:"type:text"` HeaderOverride *string `json:"header_override" gorm:"type:text"` + Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"` // add after v0.8.5 ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"` @@ -112,6 +113,10 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) { return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey) } + lock := GetChannelPollingLock(channel.Id) + lock.Lock() + defer lock.Unlock() + statusList := channel.ChannelInfo.MultiKeyStatusList // helper to get key status, default to enabled when missing getStatus := func(idx int) int { @@ -143,9 +148,6 @@ func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) { return keys[selectedIdx], selectedIdx, nil case constant.MultiKeyModePolling: // Use channel-specific lock to ensure thread-safe polling - lock := GetChannelPollingLock(channel.Id) - lock.Lock() - defer lock.Unlock() channelInfo, err := CacheGetChannelInfo(channel.Id) if err != nil { @@ -605,8 +607,12 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri return false } if channelCache.ChannelInfo.IsMultiKey { + // Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey + pollingLock := GetChannelPollingLock(channelId) + pollingLock.Lock() // 如果是多Key模式,更新缓存中的状态 handlerMultiKeyUpdate(channelCache, usingKey, status, reason) + pollingLock.Unlock() //CacheUpdateChannel(channelCache) //return true } else { @@ -637,7 +643,11 @@ func UpdateChannelStatus(channelId int, usingKey string, status int, reason stri if channel.ChannelInfo.IsMultiKey { beforeStatus := channel.Status + // Protect map writes with the same per-channel lock used by readers + pollingLock := GetChannelPollingLock(channelId) + pollingLock.Lock() handlerMultiKeyUpdate(channel, usingKey, status, reason) + pollingLock.Unlock() if beforeStatus != channel.Status { shouldUpdateAbilities = true } diff --git a/model/main.go b/model/main.go index dbf271521..1a38d371b 100644 --- a/model/main.go +++ b/model/main.go @@ -64,22 +64,6 @@ var DB *gorm.DB var LOG_DB *gorm.DB -// dropIndexIfExists drops a MySQL index only if it exists to avoid noisy 1091 errors -func dropIndexIfExists(tableName string, indexName string) { - if !common.UsingMySQL { - return - } - var count int64 - // Check index existence via information_schema - err := DB.Raw( - "SELECT COUNT(1) FROM information_schema.statistics WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", - tableName, indexName, - ).Scan(&count).Error - if err == nil && count > 0 { - _ = DB.Exec("ALTER TABLE " + tableName + " DROP INDEX " + indexName + ";").Error - } -} - func createRootAccountIfNeed() error { var user User //if user.Status != common.UserStatusEnabled { @@ -263,16 +247,6 @@ func InitLogDB() (err error) { } func migrateDB() error { - // 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录 - // 删除单列唯一索引(列级 UNIQUE)及早期命名方式,防止与新复合唯一索引 (model_name, deleted_at) 冲突 - dropIndexIfExists("models", "uk_model_name") // 新版复合索引名称(若已存在) - dropIndexIfExists("models", "model_name") // 旧版列级唯一索引名称 - - dropIndexIfExists("vendors", "uk_vendor_name") // 新版复合索引名称(若已存在) - dropIndexIfExists("vendors", "name") // 旧版列级唯一索引名称 - //if !common.UsingPostgreSQL { - // return migrateDBFast() - //} err := DB.AutoMigrate( &Channel{}, &Token{}, @@ -299,13 +273,6 @@ func migrateDB() error { } func migrateDBFast() error { - // 修复旧版本留下的唯一索引,允许软删除后重新插入同名记录 - // 删除单列唯一索引(列级 UNIQUE)及早期命名方式,防止与新复合唯一索引冲突 - dropIndexIfExists("models", "uk_model_name") - dropIndexIfExists("models", "model_name") - - dropIndexIfExists("vendors", "uk_vendor_name") - dropIndexIfExists("vendors", "name") var wg sync.WaitGroup diff --git a/model/model_meta.go b/model/model_meta.go index b7602b0ec..e41cbd090 100644 --- a/model/model_meta.go +++ b/model/model_meta.go @@ -20,17 +20,18 @@ type BoundChannel struct { } type Model struct { - Id int `json:"id"` - ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name,priority:1"` - Description string `json:"description,omitempty" gorm:"type:text"` - Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"` - Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"` - VendorID int `json:"vendor_id,omitempty" gorm:"index"` - Endpoints string `json:"endpoints,omitempty" gorm:"type:text"` - Status int `json:"status" gorm:"default:1"` - CreatedTime int64 `json:"created_time" gorm:"bigint"` - UpdatedTime int64 `json:"updated_time" gorm:"bigint"` - DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name,priority:2"` + Id int `json:"id"` + ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name_delete_at,priority:1"` + Description string `json:"description,omitempty" gorm:"type:text"` + Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"` + Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"` + VendorID int `json:"vendor_id,omitempty" gorm:"index"` + Endpoints string `json:"endpoints,omitempty" gorm:"type:text"` + Status int `json:"status" gorm:"default:1"` + SyncOfficial int `json:"sync_official" gorm:"default:1"` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + UpdatedTime int64 `json:"updated_time" gorm:"bigint"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name_delete_at,priority:2"` BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"` EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"` diff --git a/model/pricing.go b/model/pricing.go index 3c9349de5..c1192a3d9 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -155,9 +155,12 @@ func updatePricing() { vendorMap[vendors[i].Id] = &vendors[i] } + // 初始化默认供应商映射 + initDefaultVendorMapping(metaMap, vendorMap, enableAbilities) + // 构建对前端友好的供应商列表 - vendorsList = make([]PricingVendor, 0, len(vendors)) - for _, v := range vendors { + vendorsList = make([]PricingVendor, 0, len(vendorMap)) + for _, v := range vendorMap { vendorsList = append(vendorsList, PricingVendor{ ID: v.Id, Name: v.Name, diff --git a/model/pricing_default.go b/model/pricing_default.go new file mode 100644 index 000000000..db64cafbb --- /dev/null +++ b/model/pricing_default.go @@ -0,0 +1,128 @@ +package model + +import ( + "strings" +) + +// 简化的供应商映射规则 +var defaultVendorRules = map[string]string{ + "gpt": "OpenAI", + "dall-e": "OpenAI", + "whisper": "OpenAI", + "o1": "OpenAI", + "o3": "OpenAI", + "claude": "Anthropic", + "gemini": "Google", + "moonshot": "Moonshot", + "kimi": "Moonshot", + "chatglm": "智谱", + "glm-": "智谱", + "qwen": "阿里巴巴", + "deepseek": "DeepSeek", + "abab": "MiniMax", + "ernie": "百度", + "spark": "讯飞", + "hunyuan": "腾讯", + "command": "Cohere", + "@cf/": "Cloudflare", + "360": "360", + "yi": "零一万物", + "jina": "Jina", + "mistral": "Mistral", + "grok": "xAI", + "llama": "Meta", + "doubao": "字节跳动", + "kling": "快手", + "jimeng": "即梦", + "vidu": "Vidu", +} + +// 供应商默认图标映射 +var defaultVendorIcons = map[string]string{ + "OpenAI": "OpenAI", + "Anthropic": "Claude.Color", + "Google": "Gemini.Color", + "Moonshot": "Moonshot", + "智谱": "Zhipu.Color", + "阿里巴巴": "Qwen.Color", + "DeepSeek": "DeepSeek.Color", + "MiniMax": "Minimax.Color", + "百度": "Wenxin.Color", + "讯飞": "Spark.Color", + "腾讯": "Hunyuan.Color", + "Cohere": "Cohere.Color", + "Cloudflare": "Cloudflare.Color", + "360": "Ai360.Color", + "零一万物": "Yi.Color", + "Jina": "Jina", + "Mistral": "Mistral.Color", + "xAI": "XAI", + "Meta": "Ollama", + "字节跳动": "Doubao.Color", + "快手": "Kling.Color", + "即梦": "Jimeng.Color", + "Vidu": "Vidu", + "微软": "AzureAI", + "Microsoft": "AzureAI", + "Azure": "AzureAI", +} + +// initDefaultVendorMapping 简化的默认供应商映射 +func initDefaultVendorMapping(metaMap map[string]*Model, vendorMap map[int]*Vendor, enableAbilities []AbilityWithChannel) { + for _, ability := range enableAbilities { + modelName := ability.Model + if _, exists := metaMap[modelName]; exists { + continue + } + + // 匹配供应商 + vendorID := 0 + modelLower := strings.ToLower(modelName) + for pattern, vendorName := range defaultVendorRules { + if strings.Contains(modelLower, pattern) { + vendorID = getOrCreateVendor(vendorName, vendorMap) + break + } + } + + // 创建模型元数据 + metaMap[modelName] = &Model{ + ModelName: modelName, + VendorID: vendorID, + Status: 1, + NameRule: NameRuleExact, + } + } +} + +// 查找或创建供应商 +func getOrCreateVendor(vendorName string, vendorMap map[int]*Vendor) int { + // 查找现有供应商 + for id, vendor := range vendorMap { + if vendor.Name == vendorName { + return id + } + } + + // 创建新供应商 + newVendor := &Vendor{ + Name: vendorName, + Status: 1, + Icon: getDefaultVendorIcon(vendorName), + } + + if err := newVendor.Insert(); err != nil { + return 0 + } + + vendorMap[newVendor.Id] = newVendor + return newVendor.Id +} + +// 获取供应商默认图标 +func getDefaultVendorIcon(vendorName string) string { + if icon, exists := defaultVendorIcons[vendorName]; exists { + return icon + } + return "" +} diff --git a/model/task.go b/model/task.go index 9e4177ba0..4c64a5293 100644 --- a/model/task.go +++ b/model/task.go @@ -77,7 +77,7 @@ type SyncTaskQueryParams struct { UserIDs []int } -func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task { +func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.RelayInfo) *Task { t := &Task{ UserId: relayInfo.UserId, SubmitTime: time.Now().Unix(), diff --git a/model/twofa.go b/model/twofa.go index 8e97289f9..2a3d33530 100644 --- a/model/twofa.go +++ b/model/twofa.go @@ -16,7 +16,7 @@ type TwoFA struct { Id int `json:"id" gorm:"primaryKey"` UserId int `json:"user_id" gorm:"unique;not null;index"` Secret string `json:"-" gorm:"type:varchar(255);not null"` // TOTP密钥,不返回给前端 - IsEnabled bool `json:"is_enabled" gorm:"default:false"` + IsEnabled bool `json:"is_enabled"` FailedAttempts int `json:"failed_attempts" gorm:"default:0"` LockedUntil *time.Time `json:"locked_until,omitempty"` LastUsedAt *time.Time `json:"last_used_at,omitempty"` @@ -30,7 +30,7 @@ type TwoFABackupCode struct { Id int `json:"id" gorm:"primaryKey"` UserId int `json:"user_id" gorm:"not null;index"` CodeHash string `json:"-" gorm:"type:varchar(255);not null"` // 备用码哈希 - IsUsed bool `json:"is_used" gorm:"default:false"` + IsUsed bool `json:"is_used"` UsedAt *time.Time `json:"used_at,omitempty"` CreatedAt time.Time `json:"created_at"` DeletedAt gorm.DeletedAt `json:"-" gorm:"index"` diff --git a/model/user.go b/model/user.go index 29d7a4462..ea0584c5a 100644 --- a/model/user.go +++ b/model/user.go @@ -91,6 +91,68 @@ func (user *User) SetSetting(setting dto.UserSetting) { user.Setting = string(settingBytes) } +// 根据用户角色生成默认的边栏配置 +func generateDefaultSidebarConfigForRole(userRole int) string { + defaultConfig := map[string]interface{}{} + + // 聊天区域 - 所有用户都可以访问 + defaultConfig["chat"] = map[string]interface{}{ + "enabled": true, + "playground": true, + "chat": true, + } + + // 控制台区域 - 所有用户都可以访问 + defaultConfig["console"] = map[string]interface{}{ + "enabled": true, + "detail": true, + "token": true, + "log": true, + "midjourney": true, + "task": true, + } + + // 个人中心区域 - 所有用户都可以访问 + defaultConfig["personal"] = map[string]interface{}{ + "enabled": true, + "topup": true, + "personal": true, + } + + // 管理员区域 - 根据角色决定 + if userRole == common.RoleAdminUser { + // 管理员可以访问管理员区域,但不能访问系统设置 + defaultConfig["admin"] = map[string]interface{}{ + "enabled": true, + "channel": true, + "models": true, + "redemption": true, + "user": true, + "setting": false, // 管理员不能访问系统设置 + } + } else if userRole == common.RoleRootUser { + // 超级管理员可以访问所有功能 + defaultConfig["admin"] = map[string]interface{}{ + "enabled": true, + "channel": true, + "models": true, + "redemption": true, + "user": true, + "setting": true, + } + } + // 普通用户不包含admin区域 + + // 转换为JSON字符串 + configBytes, err := json.Marshal(defaultConfig) + if err != nil { + common.SysLog("生成默认边栏配置失败: " + err.Error()) + return "" + } + + return string(configBytes) +} + // CheckUserExistOrDeleted check if user exist or deleted, if not exist, return false, nil, if deleted or exist, return true, nil func CheckUserExistOrDeleted(username string, email string) (bool, error) { var user User @@ -320,10 +382,34 @@ func (user *User) Insert(inviterId int) error { user.Quota = common.QuotaForNewUser //user.SetAccessToken(common.GetUUID()) user.AffCode = common.GetRandomString(4) + + // 初始化用户设置,包括默认的边栏配置 + if user.Setting == "" { + defaultSetting := dto.UserSetting{} + // 这里暂时不设置SidebarModules,因为需要在用户创建后根据角色设置 + user.SetSetting(defaultSetting) + } + result := DB.Create(user) if result.Error != nil { return result.Error } + + // 用户创建成功后,根据角色初始化边栏配置 + // 需要重新获取用户以确保有正确的ID和Role + var createdUser User + if err := DB.Where("username = ?", user.Username).First(&createdUser).Error; err == nil { + // 生成基于角色的默认边栏配置 + defaultSidebarConfig := generateDefaultSidebarConfigForRole(createdUser.Role) + if defaultSidebarConfig != "" { + currentSetting := createdUser.GetSetting() + currentSetting.SidebarModules = defaultSidebarConfig + createdUser.SetSetting(currentSetting) + createdUser.Update(false) + common.SysLog(fmt.Sprintf("为新用户 %s (角色: %d) 初始化边栏配置", createdUser.Username, createdUser.Role)) + } + } + if common.QuotaForNewUser > 0 { RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", logger.LogQuota(common.QuotaForNewUser))) } diff --git a/model/vendor_meta.go b/model/vendor_meta.go index 88439f249..20deaea9b 100644 --- a/model/vendor_meta.go +++ b/model/vendor_meta.go @@ -14,13 +14,13 @@ import ( type Vendor struct { Id int `json:"id"` - Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name,priority:1"` + Name string `json:"name" gorm:"size:128;not null;uniqueIndex:uk_vendor_name_delete_at,priority:1"` Description string `json:"description,omitempty" gorm:"type:text"` Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"` Status int `json:"status" gorm:"default:1"` CreatedTime int64 `json:"created_time" gorm:"bigint"` UpdatedTime int64 `json:"updated_time" gorm:"bigint"` - DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name,priority:2"` + DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_vendor_name_delete_at,priority:2"` } // Insert 创建新的供应商记录 diff --git a/relay/audio_handler.go b/relay/audio_handler.go index 711cc7a9b..1357e3816 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -53,7 +53,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index ec7491334..02de99567 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -30,16 +30,16 @@ type Adaptor interface { } type TaskAdaptor interface { - Init(info *relaycommon.TaskRelayInfo) + Init(info *relaycommon.RelayInfo) - ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError + ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError - BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) - BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error - BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) + BuildRequestURL(info *relaycommon.RelayInfo) (string, error) + BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error + BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) - DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) - DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError) + DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) + DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, err *dto.TaskError) GetModelList() []string GetChannelName() string diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 518d25cea..a065caff7 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -264,9 +264,8 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http } resp, err := client.Do(req) - if err != nil { - return nil, err + return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed")) } if resp == nil { return nil, errors.New("resp is nil") @@ -277,7 +276,7 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http return resp, nil } -func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { fullRequestURL, err := a.BuildRequestURL(info) if err != nil { return nil, err @@ -294,7 +293,7 @@ func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, if err != nil { return nil, fmt.Errorf("setup request header failed: %w", err) } - resp, err := doRequest(c, req, info.RelayInfo) + resp, err := doRequest(c, req, info) if err != nil { return nil, fmt.Errorf("do request failed: %w", err) } diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index 6744f8ba6..0577ebcb7 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -81,20 +81,23 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if strings.HasSuffix(info.UpstreamModelName, "-search") { info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search") request.Model = info.UpstreamModelName - toMap := request.ToMap() - toMap["web_search"] = map[string]any{ - "enable": true, - "enable_citation": true, - "enable_trace": true, - "enable_status": false, + if len(request.WebSearch) == 0 { + toMap := request.ToMap() + toMap["web_search"] = map[string]any{ + "enable": true, + "enable_citation": true, + "enable_trace": true, + "enable_status": false, + } + return toMap, nil } - return toMap, nil + return request, nil } return request, nil } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { - return nil, nil + return nil, errors.New("not implemented") } func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 0c445bb9a..511db2c6b 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -32,7 +32,7 @@ func stopReasonClaude2OpenAI(reason string) string { case "end_turn": return "stop" case "max_tokens": - return "max_tokens" + return "length" case "tool_use": return "tool_calls" default: @@ -274,19 +274,28 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe claudeMessages := make([]dto.ClaudeMessage, 0) isFirstMessage := true + // 初始化system消息数组,用于累积多个system消息 + var systemMessages []dto.ClaudeMediaMessage + for _, message := range formatMessages { if message.Role == "system" { + // 根据Claude API规范,system字段使用数组格式更有通用性 if message.IsStringContent() { - claudeRequest.System = message.StringContent() + systemMessages = append(systemMessages, dto.ClaudeMediaMessage{ + Type: "text", + Text: common.GetPointer[string](message.StringContent()), + }) } else { - contents := message.ParseContent() - content := "" - for _, ctx := range contents { + // 支持复合内容的system消息(虽然不常见,但需要考虑完整性) + for _, ctx := range message.ParseContent() { if ctx.Type == "text" { - content += ctx.Text + systemMessages = append(systemMessages, dto.ClaudeMediaMessage{ + Type: "text", + Text: common.GetPointer[string](ctx.Text), + }) } + // 未来可以在这里扩展对图片等其他类型的支持 } - claudeRequest.System = content } } else { if isFirstMessage { @@ -392,6 +401,12 @@ func RequestOpenAI2ClaudeMessage(c *gin.Context, textRequest dto.GeneralOpenAIRe claudeMessages = append(claudeMessages, claudeMessage) } } + + // 设置累积的system消息 + if len(systemMessages) > 0 { + claudeRequest.System = systemMessages + } + claudeRequest.Prompt = "" claudeRequest.Messages = claudeMessages return &claudeRequest, nil @@ -426,7 +441,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse choice.Delta.Role = "assistant" } else if claudeResponse.Type == "content_block_start" { if claudeResponse.ContentBlock != nil { - //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text) + // 如果是文本块,尽可能发送首段文本(若存在) + if claudeResponse.ContentBlock.Type == "text" && claudeResponse.ContentBlock.Text != nil { + choice.Delta.SetContentString(*claudeResponse.ContentBlock.Text) + } if claudeResponse.ContentBlock.Type == "tool_use" { tools = append(tools, dto.ToolCallResponse{ Index: common.GetPointer(fcIdx), diff --git a/relay/channel/gemini/relay-gemini-native.go b/relay/channel/gemini/relay-gemini-native.go index 974a22f50..564b86908 100644 --- a/relay/channel/gemini/relay-gemini-native.go +++ b/relay/channel/gemini/relay-gemini-native.go @@ -46,6 +46,32 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount + if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") { + imageOutputCounts := 0 + for _, candidate := range geminiResponse.Candidates { + for _, part := range candidate.Content.Parts { + if part.InlineData != nil && strings.HasPrefix(part.InlineData.MimeType, "image/") { + imageOutputCounts++ + } + } + } + if imageOutputCounts != 0 { + usage.CompletionTokens = usage.CompletionTokens - imageOutputCounts*1290 + usage.TotalTokens = usage.TotalTokens - imageOutputCounts*1290 + c.Set("gemini_image_tokens", imageOutputCounts*1290) + } + } + + // if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") { + // for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails { + // if detail.Modality == "IMAGE" { + // usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount + // usage.TotalTokens = usage.TotalTokens - detail.TokenCount + // c.Set("gemini_image_tokens", detail.TokenCount) + // } + // } + // } + for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails { if detail.Modality == "AUDIO" { usage.PromptTokensDetails.AudioTokens = detail.TokenCount @@ -136,6 +162,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn usage.PromptTokensDetails.TextTokens = detail.TokenCount } } + + if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") { + for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails { + if detail.Modality == "IMAGE" { + usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount + usage.TotalTokens = usage.TotalTokens - detail.TokenCount + c.Set("gemini_image_tokens", detail.TokenCount) + } + } + } } // 直接发送 GeminiChatResponse 响应 diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index c54eb5b6b..eb4afbae1 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -749,7 +749,16 @@ func responseGeminiChat2OpenAI(c *gin.Context, response *dto.GeminiChatResponse) var texts []string var toolCalls []dto.ToolCallResponse for _, part := range candidate.Content.Parts { - if part.FunctionCall != nil { + if part.InlineData != nil { + // 媒体内容 + if strings.HasPrefix(part.InlineData.MimeType, "image") { + imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")" + texts = append(texts, imgText) + } else { + // 其他媒体类型,直接显示链接 + texts = append(texts, fmt.Sprintf("[media](data:%s;base64,%s)", part.InlineData.MimeType, part.InlineData.Data)) + } + } else if part.FunctionCall != nil { choice.FinishReason = constant.FinishReasonToolCalls if call := getResponseToolCall(&part); call != nil { toolCalls = append(toolCalls, *call) diff --git a/relay/channel/mokaai/constants.go b/relay/channel/mokaai/constants.go index 415d83b7f..385a0876b 100644 --- a/relay/channel/mokaai/constants.go +++ b/relay/channel/mokaai/constants.go @@ -6,4 +6,4 @@ var ModelList = []string{ "m3e-small", } -var ChannelName = "mokaai" \ No newline at end of file +var ChannelName = "mokaai" diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 3756e677b..1d8286a43 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -537,8 +537,14 @@ func detectImageMimeType(filename string) string { func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { // 转换模型推理力度后缀 effort, originModel := parseReasoningEffortFromModelSuffix(request.Model) - if effort != "" && request.Reasoning != nil { - request.Reasoning.Effort = effort + if effort != "" { + if request.Reasoning == nil { + request.Reasoning = &dto.Reasoning{ + Effort: effort, + } + } else { + request.Reasoning.Effort = effort + } request.Model = originModel } return request, nil diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index cce9235b5..4b13a7df1 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -2,6 +2,7 @@ package openai import ( "bytes" + "encoding/json" "fmt" "io" "math" @@ -280,11 +281,6 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) { defer service.CloseResponseBodyGracefully(resp) - // count tokens by audio file duration - audioTokens, err := countAudioTokens(c) - if err != nil { - return types.NewError(err, types.ErrorCodeCountTokenFailed), nil - } responseBody, err := io.ReadAll(resp.Body) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil @@ -292,6 +288,26 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel // 写入新的 response body service.IOCopyBytesGracefully(c, resp, responseBody) + var responseData struct { + Usage *dto.Usage `json:"usage"` + } + if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil { + if responseData.Usage.TotalTokens > 0 { + usage := responseData.Usage + if usage.PromptTokens == 0 { + usage.PromptTokens = usage.InputTokens + } + if usage.CompletionTokens == 0 { + usage.CompletionTokens = usage.OutputTokens + } + return nil, usage + } + } + + audioTokens, err := countAudioTokens(c) + if err != nil { + return types.NewError(err, types.ErrorCodeCountTokenFailed), nil + } usage := &dto.Usage{} usage.PromptTokens = audioTokens usage.CompletionTokens = 0 diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index ab2aa8a4a..e188889e4 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -46,9 +46,17 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http usage.PromptTokensDetails.CachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens } } + if info == nil || info.ResponsesUsageInfo == nil || info.ResponsesUsageInfo.BuiltInTools == nil { + return &usage, nil + } // 解析 Tools 用量 for _, tool := range responsesResponse.Tools { - info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])].CallCount++ + buildToolinfo, ok := info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])] + if !ok || buildToolinfo == nil { + logger.LogError(c, fmt.Sprintf("BuiltInTools not found for tool type: %v", tool["type"])) + continue + } + buildToolinfo.CallCount++ } return &usage, nil } @@ -72,10 +80,16 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp sendResponsesStreamData(c, streamResponse, data) switch streamResponse.Type { case "response.completed": - if streamResponse.Response.Usage != nil { - usage.PromptTokens = streamResponse.Response.Usage.InputTokens - usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens - usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + if streamResponse.Response != nil && streamResponse.Response.Usage != nil { + if streamResponse.Response.Usage.InputTokens != 0 { + usage.PromptTokens = streamResponse.Response.Usage.InputTokens + } + if streamResponse.Response.Usage.OutputTokens != 0 { + usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens + } + if streamResponse.Response.Usage.TotalTokens != 0 { + usage.TotalTokens = streamResponse.Response.Usage.TotalTokens + } if streamResponse.Response.Usage.InputTokensDetails != nil { usage.PromptTokensDetails.CachedTokens = streamResponse.Response.Usage.InputTokensDetails.CachedTokens } @@ -92,6 +106,8 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp } } } + } else { + logger.LogError(c, "failed to unmarshal stream response: "+err.Error()) } return true }) @@ -107,10 +123,10 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp } if usage.PromptTokens == 0 && usage.CompletionTokens != 0 { - usage.PromptTokens = usage.CompletionTokens - } else { - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + usage.PromptTokens = info.PromptTokens } + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + return usage, nil } diff --git a/relay/channel/task/jimeng/adaptor.go b/relay/channel/task/jimeng/adaptor.go index a5ada1370..955e592a2 100644 --- a/relay/channel/task/jimeng/adaptor.go +++ b/relay/channel/task/jimeng/adaptor.go @@ -74,7 +74,7 @@ type TaskAdaptor struct { baseURL string } -func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType a.baseURL = info.ChannelBaseUrl @@ -87,7 +87,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { } // ValidateRequestAndSetAction parses body, validates fields and sets default action. -func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { // Accept only POST /v1/video/generations as "generate" action. action := constant.TaskActionGenerate info.Action = action @@ -108,19 +108,19 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom } // BuildRequestURL constructs the upstream URL. -func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil } // BuildRequestHeader sets required headers. -func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") return a.signRequest(req, a.accessKey, a.secretKey) } // BuildRequestBody converts request into Jimeng specific format. -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, exists := c.Get("task_request") if !exists { return nil, fmt.Errorf("request not found in context") @@ -139,12 +139,12 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel } // DoRequest delegates to common helper. -func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoTaskApiRequest(a, c, info, requestBody) } // DoResponse handles upstream response, returns taskID etc. -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) diff --git a/relay/channel/task/kling/adaptor.go b/relay/channel/task/kling/adaptor.go index 1fecda08a..3d6da253b 100644 --- a/relay/channel/task/kling/adaptor.go +++ b/relay/channel/task/kling/adaptor.go @@ -4,13 +4,14 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/samber/lo" "io" "net/http" "one-api/model" "strings" "time" + "github.com/samber/lo" + "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt" "github.com/pkg/errors" @@ -37,15 +38,46 @@ type SubmitReq struct { Metadata map[string]interface{} `json:"metadata,omitempty"` } +type TrajectoryPoint struct { + X int `json:"x"` + Y int `json:"y"` +} + +type DynamicMask struct { + Mask string `json:"mask,omitempty"` + Trajectories []TrajectoryPoint `json:"trajectories,omitempty"` +} + +type CameraConfig struct { + Horizontal float64 `json:"horizontal,omitempty"` + Vertical float64 `json:"vertical,omitempty"` + Pan float64 `json:"pan,omitempty"` + Tilt float64 `json:"tilt,omitempty"` + Roll float64 `json:"roll,omitempty"` + Zoom float64 `json:"zoom,omitempty"` +} + +type CameraControl struct { + Type string `json:"type,omitempty"` + Config *CameraConfig `json:"config,omitempty"` +} + type requestPayload struct { - Prompt string `json:"prompt,omitempty"` - Image string `json:"image,omitempty"` - Mode string `json:"mode,omitempty"` - Duration string `json:"duration,omitempty"` - AspectRatio string `json:"aspect_ratio,omitempty"` - ModelName string `json:"model_name,omitempty"` - Model string `json:"model,omitempty"` // Compatible with upstreams that only recognize "model" - CfgScale float64 `json:"cfg_scale,omitempty"` + Prompt string `json:"prompt,omitempty"` + Image string `json:"image,omitempty"` + ImageTail string `json:"image_tail,omitempty"` + NegativePrompt string `json:"negative_prompt,omitempty"` + Mode string `json:"mode,omitempty"` + Duration string `json:"duration,omitempty"` + AspectRatio string `json:"aspect_ratio,omitempty"` + ModelName string `json:"model_name,omitempty"` + Model string `json:"model,omitempty"` // Compatible with upstreams that only recognize "model" + CfgScale float64 `json:"cfg_scale,omitempty"` + StaticMask string `json:"static_mask,omitempty"` + DynamicMasks []DynamicMask `json:"dynamic_masks,omitempty"` + CameraControl *CameraControl `json:"camera_control,omitempty"` + CallbackUrl string `json:"callback_url,omitempty"` + ExternalTaskId string `json:"external_task_id,omitempty"` } type responsePayload struct { @@ -79,7 +111,7 @@ type TaskAdaptor struct { baseURL string } -func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType a.baseURL = info.ChannelBaseUrl a.apiKey = info.ApiKey @@ -88,7 +120,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { } // ValidateRequestAndSetAction parses body, validates fields and sets default action. -func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { // Accept only POST /v1/video/generations as "generate" action. action := constant.TaskActionGenerate info.Action = action @@ -109,13 +141,13 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom } // BuildRequestURL constructs the upstream URL. -func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video") return fmt.Sprintf("%s%s", a.baseURL, path), nil } // BuildRequestHeader sets required headers. -func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { token, err := a.createJWTToken() if err != nil { return fmt.Errorf("failed to create JWT token: %w", err) @@ -129,7 +161,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info } // BuildRequestBody converts request into Kling specific format. -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, exists := c.Get("task_request") if !exists { return nil, fmt.Errorf("request not found in context") @@ -140,6 +172,9 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel if err != nil { return nil, err } + if body.Image == "" && body.ImageTail == "" { + c.Set("action", constant.TaskActionTextGenerate) + } data, err := json.Marshal(body) if err != nil { return nil, err @@ -148,7 +183,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel } // DoRequest delegates to common helper. -func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { if action := c.GetString("action"); action != "" { info.Action = action } @@ -156,7 +191,7 @@ func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, } // DoResponse handles upstream response, returns taskID etc. -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -222,14 +257,19 @@ func (a *TaskAdaptor) GetChannelName() string { func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) { r := requestPayload{ - Prompt: req.Prompt, - Image: req.Image, - Mode: defaultString(req.Mode, "std"), - Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)), - AspectRatio: a.getAspectRatio(req.Size), - ModelName: req.Model, - Model: req.Model, // Keep consistent with model_name, double writing improves compatibility - CfgScale: 0.5, + Prompt: req.Prompt, + Image: req.Image, + Mode: defaultString(req.Mode, "std"), + Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)), + AspectRatio: a.getAspectRatio(req.Size), + ModelName: req.Model, + Model: req.Model, // Keep consistent with model_name, double writing improves compatibility + CfgScale: 0.5, + StaticMask: "", + DynamicMasks: []DynamicMask{}, + CameraControl: nil, + CallbackUrl: "", + ExternalTaskId: "", } if r.ModelName == "" { r.ModelName = "kling-v1" diff --git a/relay/channel/task/suno/adaptor.go b/relay/channel/task/suno/adaptor.go index df2bb99ea..237513d75 100644 --- a/relay/channel/task/suno/adaptor.go +++ b/relay/channel/task/suno/adaptor.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -16,6 +15,8 @@ import ( "one-api/service" "strings" "time" + + "github.com/gin-gonic/gin" ) type TaskAdaptor struct { @@ -26,11 +27,11 @@ func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, fmt.Errorf("not implement") // todo implement this method if needed } -func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType } -func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { action := strings.ToUpper(c.Param("action")) var sunoRequest *dto.SunoSubmitReq @@ -58,20 +59,20 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom return nil } -func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { baseURL := info.ChannelBaseUrl fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action) return fullRequestURL, nil } -func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Authorization", "Bearer "+info.ApiKey) return nil } -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) { +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { sunoRequest, ok := c.Get("task_request") if !ok { err := common.UnmarshalBodyReusable(c, &sunoRequest) @@ -86,11 +87,11 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRel return bytes.NewReader(data), nil } -func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoTaskApiRequest(a, c, info, requestBody) } -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) diff --git a/relay/channel/task/vidu/adaptor.go b/relay/channel/task/vidu/adaptor.go index b0cc0bdc8..c82c1c0e8 100644 --- a/relay/channel/task/vidu/adaptor.go +++ b/relay/channel/task/vidu/adaptor.go @@ -84,12 +84,12 @@ type TaskAdaptor struct { baseURL string } -func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) { +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType a.baseURL = info.ChannelBaseUrl } -func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError { +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { var req SubmitReq if err := c.ShouldBindJSON(&req); err != nil { return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest) @@ -109,7 +109,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom return nil } -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) { +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) { v, exists := c.Get("task_request") if !exists { return nil, fmt.Errorf("request not found in context") @@ -132,7 +132,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayI return bytes.NewReader(data), nil } -func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { var path string switch info.Action { case constant.TaskActionGenerate: @@ -143,21 +143,21 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, return fmt.Sprintf("%s/ent/v2%s", a.baseURL, path), nil } -func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Token "+info.ApiKey) return nil } -func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { if action := c.GetString("action"); action != "" { info.Action = action } return channel.DoTaskApiRequest(a, c, info, requestBody) } -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index b46cb9525..0af019da4 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -2,6 +2,7 @@ package volcengine import ( "bytes" + "encoding/json" "errors" "fmt" "io" @@ -214,6 +215,12 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } + // 适配 方舟deepseek混合模型 的 thinking 后缀 + if strings.HasSuffix(info.UpstreamModelName, "-thinking") && strings.HasPrefix(info.UpstreamModelName, "deepseek") { + info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking") + request.Model = info.UpstreamModelName + request.THINKING = json.RawMessage(`{"type": "enabled"}`) + } return request, nil } diff --git a/relay/claude_handler.go b/relay/claude_handler.go index 59c052f62..dbdc6ee1c 100644 --- a/relay/claude_handler.go +++ b/relay/claude_handler.go @@ -111,7 +111,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ httpResp = resp.(*http.Response) info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/relay/common/override.go b/relay/common/override.go index c8f216ed5..212cf7b47 100644 --- a/relay/common/override.go +++ b/relay/common/override.go @@ -5,6 +5,8 @@ import ( "fmt" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + "regexp" + "strconv" "strings" ) @@ -151,7 +153,9 @@ func checkConditions(jsonStr string, conditions []ConditionOperation, logic stri } func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) { - value := gjson.Get(jsonStr, condition.Path) + // 处理负数索引 + path := processNegativeIndex(jsonStr, condition.Path) + value := gjson.Get(jsonStr, path) if !value.Exists() { if condition.PassMissingKey { return true, nil @@ -177,6 +181,37 @@ func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, e return result, nil } +func processNegativeIndex(jsonStr string, path string) string { + re := regexp.MustCompile(`\.(-\d+)`) + matches := re.FindAllStringSubmatch(path, -1) + + if len(matches) == 0 { + return path + } + + result := path + for _, match := range matches { + negIndex := match[1] + index, _ := strconv.Atoi(negIndex) + + arrayPath := strings.Split(path, negIndex)[0] + if strings.HasSuffix(arrayPath, ".") { + arrayPath = arrayPath[:len(arrayPath)-1] + } + + array := gjson.Get(jsonStr, arrayPath) + if array.IsArray() { + length := len(array.Array()) + actualIndex := length + index + if actualIndex >= 0 && actualIndex < length { + result = strings.Replace(result, match[0], "."+strconv.Itoa(actualIndex), 1) + } + } + } + + return result +} + // compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式 func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) { switch mode { @@ -274,21 +309,25 @@ func applyOperations(jsonStr string, operations []ParamOperation) (string, error if !ok { continue // 条件不满足,跳过当前操作 } + // 处理路径中的负数索引 + opPath := processNegativeIndex(result, op.Path) + opFrom := processNegativeIndex(result, op.From) + opTo := processNegativeIndex(result, op.To) switch op.Mode { case "delete": - result, err = sjson.Delete(result, op.Path) + result, err = sjson.Delete(result, opPath) case "set": - if op.KeepOrigin && gjson.Get(result, op.Path).Exists() { + if op.KeepOrigin && gjson.Get(result, opPath).Exists() { continue } - result, err = sjson.Set(result, op.Path, op.Value) + result, err = sjson.Set(result, opPath, op.Value) case "move": - result, err = moveValue(result, op.From, op.To) + result, err = moveValue(result, opFrom, opTo) case "prepend": - result, err = modifyValue(result, op.Path, op.Value, op.KeepOrigin, true) + result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, true) case "append": - result, err = modifyValue(result, op.Path, op.Value, op.KeepOrigin, false) + result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, false) default: return "", fmt.Errorf("unknown operation: %s", op.Mode) } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index caf8b452e..da572c070 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -116,6 +116,7 @@ type RelayInfo struct { *RerankerInfo *ResponsesUsageInfo *ChannelMeta + *TaskRelayInfo } func (info *RelayInfo) InitChannelMeta(c *gin.Context) { @@ -313,7 +314,7 @@ func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest) BuiltInTools: make(map[string]*BuildInToolInfo), } if len(request.Tools) > 0 { - for _, tool := range request.Tools { + for _, tool := range request.GetToolsMap() { toolType := common.Interface2String(tool["type"]) info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{ ToolName: toolType, @@ -400,6 +401,10 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo { }, } + if info.RelayMode == relayconstant.RelayModeUnknown { + info.RelayMode = c.GetInt("relay_mode") + } + if strings.HasPrefix(c.Request.URL.Path, "/pg") { info.IsPlayground = true info.RequestURLPath = strings.TrimPrefix(info.RequestURLPath, "/pg") @@ -465,25 +470,12 @@ func (info *RelayInfo) HasSendResponse() bool { } type TaskRelayInfo struct { - *RelayInfo Action string OriginTaskID string ConsumeQuota bool } -func GenTaskRelayInfo(c *gin.Context) (*TaskRelayInfo, error) { - relayInfo, err := GenRelayInfo(c, types.RelayFormatTask, nil, nil) - if err != nil { - return nil, err - } - info := &TaskRelayInfo{ - RelayInfo: relayInfo, - } - info.InitChannelMeta(c) - return info, nil -} - type TaskSubmitReq struct { Prompt string `json:"prompt"` Model string `json:"model,omitempty"` diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 290865854..3d5efcb6d 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -2,12 +2,10 @@ package common import ( "fmt" - "github.com/gin-gonic/gin" - _ "image/gif" - _ "image/jpeg" - _ "image/png" "one-api/constant" "strings" + + "github.com/gin-gonic/gin" ) func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index 56d65a3f3..8f27fd60b 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -130,7 +130,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types jsonData, err := common.Marshal(convertedRequest) if err != nil { - return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) + return types.NewError(err, types.ErrorCodeJsonMarshalFailed, types.ErrOptionWithSkipRetry()) } // apply param override @@ -158,7 +158,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types httpResp = resp.(*http.Response) info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newApiErr := service.RelayErrorHandler(httpResp, false) + newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newApiErr, statusCodeMappingStr) return newApiErr @@ -195,6 +195,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage imageTokens := usage.PromptTokensDetails.ImageTokens audioTokens := usage.PromptTokensDetails.AudioTokens completionTokens := usage.CompletionTokens + cachedCreationTokens := usage.PromptTokensDetails.CachedCreationTokens + modelName := relayInfo.OriginModelName tokenName := ctx.GetString("token_name") @@ -204,6 +206,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage modelRatio := relayInfo.PriceData.ModelRatio groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio modelPrice := relayInfo.PriceData.ModelPrice + cachedCreationRatio := relayInfo.PriceData.CacheCreationRatio // Convert values to decimal for precise calculation dPromptTokens := decimal.NewFromInt(int64(promptTokens)) @@ -211,12 +214,14 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage dImageTokens := decimal.NewFromInt(int64(imageTokens)) dAudioTokens := decimal.NewFromInt(int64(audioTokens)) dCompletionTokens := decimal.NewFromInt(int64(completionTokens)) + dCachedCreationTokens := decimal.NewFromInt(int64(cachedCreationTokens)) dCompletionRatio := decimal.NewFromFloat(completionRatio) dCacheRatio := decimal.NewFromFloat(cacheRatio) dImageRatio := decimal.NewFromFloat(imageRatio) dModelRatio := decimal.NewFromFloat(modelRatio) dGroupRatio := decimal.NewFromFloat(groupRatio) dModelPrice := decimal.NewFromFloat(modelPrice) + dCachedCreationRatio := decimal.NewFromFloat(cachedCreationRatio) dQuotaPerUnit := decimal.NewFromFloat(common.QuotaPerUnit) ratio := dModelRatio.Mul(dGroupRatio) @@ -284,6 +289,11 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage baseTokens = baseTokens.Sub(dCacheTokens) cachedTokensWithRatio = dCacheTokens.Mul(dCacheRatio) } + var dCachedCreationTokensWithRatio decimal.Decimal + if !dCachedCreationTokens.IsZero() { + baseTokens = baseTokens.Sub(dCachedCreationTokens) + dCachedCreationTokensWithRatio = dCachedCreationTokens.Mul(dCachedCreationRatio) + } // 减去 image tokens var imageTokensWithRatio decimal.Decimal @@ -302,7 +312,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()) } } - promptQuota := baseTokens.Add(cachedTokensWithRatio).Add(imageTokensWithRatio) + promptQuota := baseTokens.Add(cachedTokensWithRatio). + Add(imageTokensWithRatio). + Add(dCachedCreationTokensWithRatio) completionQuota := dCompletionTokens.Mul(dCompletionRatio) @@ -314,11 +326,22 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage } else { quotaCalculateDecimal = dModelPrice.Mul(dQuotaPerUnit).Mul(dGroupRatio) } + var dGeminiImageOutputQuota decimal.Decimal + var imageOutputPrice float64 + if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") { + imageOutputPrice = operation_setting.GetGeminiImageOutputPricePerMillionTokens(modelName) + if imageOutputPrice > 0 { + dImageOutputTokens := decimal.NewFromInt(int64(ctx.GetInt("gemini_image_tokens"))) + dGeminiImageOutputQuota = decimal.NewFromFloat(imageOutputPrice).Div(decimal.NewFromInt(1000000)).Mul(dImageOutputTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit) + } + } // 添加 responses tools call 调用的配额 quotaCalculateDecimal = quotaCalculateDecimal.Add(dWebSearchQuota) quotaCalculateDecimal = quotaCalculateDecimal.Add(dFileSearchQuota) // 添加 audio input 独立计费 quotaCalculateDecimal = quotaCalculateDecimal.Add(audioInputQuota) + // 添加 Gemini image output 计费 + quotaCalculateDecimal = quotaCalculateDecimal.Add(dGeminiImageOutputQuota) quota := int(quotaCalculateDecimal.Round(0).IntPart()) totalTokens := promptTokens + completionTokens @@ -384,6 +407,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage other["image_ratio"] = imageRatio other["image_output"] = imageTokens } + if cachedCreationTokens != 0 { + other["cache_creation_tokens"] = cachedCreationTokens + other["cache_creation_ratio"] = cachedCreationRatio + } if !dWebSearchQuota.IsZero() { if relayInfo.ResponsesUsageInfo != nil { if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists { @@ -413,6 +440,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage other["audio_input_token_count"] = audioTokens other["audio_input_price"] = audioInputPrice } + if !dGeminiImageOutputQuota.IsZero() { + other["image_output_token_count"] = ctx.GetInt("gemini_image_tokens") + other["image_output_price"] = imageOutputPrice + } model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ ChannelId: relayInfo.ChannelId, PromptTokens: promptTokens, diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index 26dcf9719..3d8962bb4 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -58,7 +58,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index 460fd2f58..0252d6578 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -152,7 +152,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ httpResp = resp.(*http.Response) info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError @@ -249,7 +249,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } diff --git a/relay/helper/common.go b/relay/helper/common.go index 5b3e76743..381147ae5 100644 --- a/relay/helper/common.go +++ b/relay/helper/common.go @@ -1,7 +1,6 @@ package helper import ( - "encoding/json" "errors" "fmt" "net/http" @@ -42,7 +41,7 @@ func SetEventStreamHeaders(c *gin.Context) { } func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error { - jsonData, err := json.Marshal(resp) + jsonData, err := common.Marshal(resp) if err != nil { common.SysError("error marshalling stream response: " + err.Error()) } else { @@ -104,7 +103,7 @@ func WssString(c *gin.Context, ws *websocket.Conn, str string) error { } func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { - jsonData, err := json.Marshal(object) + jsonData, err := common.Marshal(object) if err != nil { return fmt.Errorf("error marshalling object: %w", err) } diff --git a/relay/image_handler.go b/relay/image_handler.go index 14a7103c3..9c873d47f 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -91,7 +91,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type httpResp = resp.(*http.Response) info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError @@ -120,7 +120,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type var logContent string if len(request.Size) > 0 { - logContent = fmt.Sprintf("大小 %s, 品质 %s", request.Size, quality) + logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N) } postConsumeQuota(c, info, usage.(*dto.Usage), logContent) diff --git a/relay/relay_task.go b/relay/relay_task.go index 6faec176d..9cb8cd5c8 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -24,32 +24,32 @@ import ( /* Task 任务通过平台、Action 区分任务 */ -func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { +func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { + info.InitChannelMeta(c) + // ensure TaskRelayInfo is initialized to avoid nil dereference when accessing embedded fields + if info.TaskRelayInfo == nil { + info.TaskRelayInfo = &relaycommon.TaskRelayInfo{} + } platform := constant.TaskPlatform(c.GetString("platform")) if platform == "" { platform = GetTaskPlatform(c) } - relayInfo, err := relaycommon.GenTaskRelayInfo(c) - if err != nil { - return service.TaskErrorWrapper(err, "gen_relay_info_failed", http.StatusInternalServerError) - } - relayInfo.InitChannelMeta(c) - + info.InitChannelMeta(c) adaptor := GetTaskAdaptor(platform) if adaptor == nil { return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) } - adaptor.Init(relayInfo) + adaptor.Init(info) // get & validate taskRequest 获取并验证文本请求 - taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo) + taskErr = adaptor.ValidateRequestAndSetAction(c, info) if taskErr != nil { return } - modelName := relayInfo.OriginModelName + modelName := info.OriginModelName if modelName == "" { - modelName = service.CoverTaskActionToModelName(platform, relayInfo.Action) + modelName = service.CoverTaskActionToModelName(platform, info.Action) } modelPrice, success := ratio_setting.GetModelPrice(modelName, true) if !success { @@ -62,15 +62,15 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { } // 预扣 - groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup) + groupRatio := ratio_setting.GetGroupRatio(info.UsingGroup) var ratio float64 - userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup) + userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(info.UserGroup, info.UsingGroup) if hasUserGroupRatio { ratio = modelPrice * userGroupRatio } else { ratio = modelPrice * groupRatio } - userQuota, err := model.GetUserQuota(relayInfo.UserId, false) + userQuota, err := model.GetUserQuota(info.UserId, false) if err != nil { taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) return @@ -81,8 +81,8 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { return } - if relayInfo.OriginTaskID != "" { - originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID) + if info.OriginTaskID != "" { + originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) if err != nil { taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) return @@ -91,7 +91,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) return } - if originTask.ChannelId != relayInfo.ChannelId { + if originTask.ChannelId != info.ChannelId { channel, err := model.GetChannelById(originTask.ChannelId, true) if err != nil { taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) @@ -104,19 +104,19 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { c.Set("channel_id", originTask.ChannelId) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) - relayInfo.ChannelBaseUrl = channel.GetBaseURL() - relayInfo.ChannelId = originTask.ChannelId + info.ChannelBaseUrl = channel.GetBaseURL() + info.ChannelId = originTask.ChannelId } } // build body - requestBody, err := adaptor.BuildRequestBody(c, relayInfo) + requestBody, err := adaptor.BuildRequestBody(c, info) if err != nil { taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) return } // do request - resp, err := adaptor.DoRequest(c, relayInfo, requestBody) + resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) return @@ -130,9 +130,9 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { defer func() { // release quota - if relayInfo.ConsumeQuota && taskErr == nil { + if info.ConsumeQuota && taskErr == nil { - err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true) + err := service.PostConsumeQuota(info, quota, 0, true) if err != nil { common.SysLog("error consuming token remain quota: " + err.Error()) } @@ -142,40 +142,40 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) { if hasUserGroupRatio { gRatio = userGroupRatio } - logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action) + logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, info.Action) other := make(map[string]interface{}) other["model_price"] = modelPrice other["group_ratio"] = groupRatio if hasUserGroupRatio { other["user_group_ratio"] = userGroupRatio } - model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{ - ChannelId: relayInfo.ChannelId, + model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ + ChannelId: info.ChannelId, ModelName: modelName, TokenName: tokenName, Quota: quota, Content: logContent, - TokenId: relayInfo.TokenId, - Group: relayInfo.UsingGroup, + TokenId: info.TokenId, + Group: info.UsingGroup, Other: other, }) - model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) - model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + model.UpdateUserUsedQuotaAndRequestCount(info.UserId, quota) + model.UpdateChannelUsedQuota(info.ChannelId, quota) } } }() - taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo) + taskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) if taskErr != nil { return } - relayInfo.ConsumeQuota = true + info.ConsumeQuota = true // insert task - task := model.InitTask(platform, relayInfo) + task := model.InitTask(platform, info) task.TaskID = taskID task.Quota = quota task.Data = taskData - task.Action = relayInfo.Action + task.Action = info.Action err = task.Insert() if err != nil { taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError) diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index fa3c7bbb4..46d2e25f6 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -81,7 +81,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/relay/responses_handler.go b/relay/responses_handler.go index f5f624c92..d1c5d2158 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -82,7 +82,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - newAPIError = service.RelayErrorHandler(httpResp, false) + newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError diff --git a/router/api-router.go b/router/api-router.go index be721b05f..773857385 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -114,6 +114,7 @@ func SetApiRouter(router *gin.Engine) { channelRoute.GET("/models", controller.ChannelListModels) channelRoute.GET("/models_enabled", controller.EnabledListModels) channelRoute.GET("/:id", controller.GetChannel) + channelRoute.POST("/:id/key", middleware.CriticalRateLimit(), middleware.DisableCache(), controller.GetChannelKey) channelRoute.GET("/test", controller.TestAllChannels) channelRoute.GET("/test/:id", controller.TestChannel) channelRoute.GET("/update_balance", controller.UpdateAllChannelsBalance) @@ -223,6 +224,8 @@ func SetApiRouter(router *gin.Engine) { modelsRoute := apiRouter.Group("/models") modelsRoute.Use(middleware.AdminAuth()) { + modelsRoute.GET("/sync_upstream/preview", controller.SyncUpstreamPreview) + modelsRoute.POST("/sync_upstream", controller.SyncUpstreamModels) modelsRoute.GET("/missing", controller.GetMissingModels) modelsRoute.GET("/", controller.GetAllModelsMeta) modelsRoute.GET("/search", controller.SearchModelsMeta) diff --git a/service/convert.go b/service/convert.go index ea219c4fa..b232ca396 100644 --- a/service/convert.go +++ b/service/convert.go @@ -248,9 +248,10 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon }, }) claudeResponses = append(claudeResponses, &dto.ClaudeResponse{ - Type: "content_block_delta", + Index: &info.ClaudeConvertInfo.Index, + Type: "content_block_delta", Delta: &dto.ClaudeMediaMessage{ - Type: "text", + Type: "text_delta", Text: common.GetPointer[string](openAIResponse.Choices[0].Delta.GetContentString()), }, }) diff --git a/service/error.go b/service/error.go index ef5cbbde6..5c3bddd6e 100644 --- a/service/error.go +++ b/service/error.go @@ -1,12 +1,14 @@ package service import ( + "context" "errors" "fmt" "io" "net/http" "one-api/common" "one-api/dto" + "one-api/logger" "one-api/types" "strconv" "strings" @@ -78,7 +80,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude return claudeErr } -func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) { +func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) { newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode) responseBody, err := io.ReadAll(resp.Body) @@ -94,7 +96,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)) } else { if common.DebugEnabled { - println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))) + logger.LogInfo(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))) } newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode) } diff --git a/service/file_decoder.go b/service/file_decoder.go index 94f3f0282..99fdc3f9a 100644 --- a/service/file_decoder.go +++ b/service/file_decoder.go @@ -5,6 +5,9 @@ import ( "encoding/base64" "fmt" "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" "io" "net/http" "one-api/common" diff --git a/service/image.go b/service/image.go index 252093f1f..453d8dd1c 100644 --- a/service/image.go +++ b/service/image.go @@ -21,6 +21,10 @@ func DecodeBase64ImageData(base64String string) (image.Config, string, string, e base64String = base64String[idx+1:] } + if len(base64String) == 0 { + return image.Config{}, "", "", errors.New("base64 string is empty") + } + // 将base64字符串解码为字节切片 decodedData, err := base64.StdEncoding.DecodeString(base64String) if err != nil { diff --git a/service/pre_consume_quota.go b/service/pre_consume_quota.go index 08e3f68f2..3cfabc1a4 100644 --- a/service/pre_consume_quota.go +++ b/service/pre_consume_quota.go @@ -1,7 +1,6 @@ package service import ( - "errors" "fmt" "net/http" "one-api/common" @@ -14,13 +13,13 @@ import ( "github.com/gin-gonic/gin" ) -func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) { - if preConsumedQuota != 0 { - logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota))) +func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo) { + if relayInfo.FinalPreConsumedQuota != 0 { + logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(relayInfo.FinalPreConsumedQuota))) gopool.Go(func() { relayInfoCopy := *relayInfo - err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false) + err := PostConsumeQuota(&relayInfoCopy, -relayInfo.FinalPreConsumedQuota, 0, false) if err != nil { common.SysLog("error return pre-consumed quota: " + err.Error()) } @@ -30,16 +29,16 @@ func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, pr // PreConsumeQuota checks if the user has enough quota to pre-consume. // It returns the pre-consumed quota if successful, or an error if not. -func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) { +func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) *types.NewAPIError { userQuota, err := model.GetUserQuota(relayInfo.UserId, false) if err != nil { - return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) + return types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry()) } if userQuota <= 0 { - return 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + return types.NewErrorWithStatusCode(fmt.Errorf("用户额度不足, 剩余额度: %s", logger.FormatQuota(userQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) } if userQuota-preConsumedQuota < 0 { - return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + return types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) } trustQuota := common.GetTrustQuota() @@ -66,14 +65,14 @@ func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if preConsumedQuota > 0 { err := PreConsumeTokenQuota(relayInfo, preConsumedQuota) if err != nil { - return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) + return types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog()) } err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota) if err != nil { - return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) + return types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry()) } logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota))) } relayInfo.FinalPreConsumedQuota = preConsumedQuota - return preConsumedQuota, nil + return nil } diff --git a/service/quota.go b/service/quota.go index 8f65bd20e..e078a1ad1 100644 --- a/service/quota.go +++ b/service/quota.go @@ -535,8 +535,27 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon if quotaTooLow { prompt := "您的额度即将用尽" topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) - content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}" - err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink})) + + // 根据通知方式生成不同的内容格式 + var content string + var values []interface{} + + notifyType := userSetting.NotifyType + if notifyType == "" { + notifyType = dto.NotifyTypeEmail + } + + if notifyType == dto.NotifyTypeBark { + // Bark推送使用简短文本,不支持HTML + content = "{{value}},剩余额度:{{value}},请及时充值" + values = []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota)} + } else { + // 默认内容格式,适用于Email和Webhook + content = "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。
充值链接:{{value}}" + values = []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink} + } + + err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, values)) if err != nil { common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error())) } diff --git a/service/token_counter.go b/service/token_counter.go index bac6c067b..da56523fe 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -5,6 +5,9 @@ import ( "errors" "fmt" "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" "log" "math" "one-api/common" @@ -250,13 +253,18 @@ func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, er } func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) { - if meta == nil { - return 0, errors.New("token count meta is nil") + if !constant.GetMediaToken { + return 0, nil + } + if !constant.GetMediaTokenNotStream && !info.IsStream { + return 0, nil } - if info.RelayFormat == types.RelayFormatOpenAIRealtime { return 0, nil } + if meta == nil { + return 0, errors.New("token count meta is nil") + } model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel) tkm := 0 @@ -276,7 +284,7 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco shouldFetchFiles := true - if info.RelayFormat == types.RelayFormatOpenAIRealtime || info.RelayFormat == types.RelayFormatGemini { + if info.RelayFormat == types.RelayFormatGemini { shouldFetchFiles = false } @@ -297,19 +305,43 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco file.FileType = types.FileTypeFile } file.MimeType = mineType + } else if strings.HasPrefix(file.OriginData, "data:") { + // get mime type from base64 header + parts := strings.SplitN(file.OriginData, ",", 2) + if len(parts) >= 1 { + header := parts[0] + // Extract mime type from "data:mime/type;base64" format + if strings.Contains(header, ":") && strings.Contains(header, ";") { + mimeStart := strings.Index(header, ":") + 1 + mimeEnd := strings.Index(header, ";") + if mimeStart < mimeEnd { + mineType := header[mimeStart:mimeEnd] + if strings.HasPrefix(mineType, "image/") { + file.FileType = types.FileTypeImage + } else if strings.HasPrefix(mineType, "video/") { + file.FileType = types.FileTypeVideo + } else if strings.HasPrefix(mineType, "audio/") { + file.FileType = types.FileTypeAudio + } else { + file.FileType = types.FileTypeFile + } + file.MimeType = mineType + } + } + } } } } - for _, file := range meta.Files { + for i, file := range meta.Files { switch file.FileType { case types.FileTypeImage: - if info.RelayFormat == types.RelayFormatGemini { + if info.RelayFormat == types.RelayFormatGemini && !strings.HasPrefix(model, "gemini-2.5-flash-image-preview") { tkm += 256 } else { token, err := getImageToken(file, model, info.IsStream) if err != nil { - return 0, fmt.Errorf("error counting image token: %v", err) + return 0, fmt.Errorf("error counting image token, media index[%d], original data[%s], err: %v", i, file.OriginData, err) } tkm += token } @@ -328,33 +360,6 @@ func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relayco return tkm, nil } -//func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) { -// tkm := 0 -// msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream) -// if err != nil { -// return 0, err -// } -// tkm += msgTokens -// if request.Tools != nil { -// openaiTools := request.Tools -// countStr := "" -// for _, tool := range openaiTools { -// countStr = tool.Function.Name -// if tool.Function.Description != "" { -// countStr += tool.Function.Description -// } -// if tool.Function.Parameters != nil { -// countStr += fmt.Sprintf("%v", tool.Function.Parameters) -// } -// } -// toolTokens := CountTokenInput(countStr, request.Model) -// tkm += 8 -// tkm += toolTokens -// } -// -// return tkm, nil -//} - func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) { tkm := 0 @@ -514,56 +519,6 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, return textToken, audioToken, nil } -//func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) { -// //recover when panic -// tokenEncoder := getTokenEncoder(model) -// // Reference: -// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb -// // https://github.com/pkoukk/tiktoken-go/issues/6 -// // -// // Every message follows <|start|>{role/name}\n{content}<|end|>\n -// var tokensPerMessage int -// var tokensPerName int -// -// tokensPerMessage = 3 -// tokensPerName = 1 -// -// tokenNum := 0 -// for _, message := range messages { -// tokenNum += tokensPerMessage -// tokenNum += getTokenNum(tokenEncoder, message.Role) -// if message.Content != nil { -// if message.Name != nil { -// tokenNum += tokensPerName -// tokenNum += getTokenNum(tokenEncoder, *message.Name) -// } -// arrayContent := message.ParseContent() -// for _, m := range arrayContent { -// if m.Type == dto.ContentTypeImageURL { -// imageUrl := m.GetImageMedia() -// imageTokenNum, err := getImageToken(info, imageUrl, model, stream) -// if err != nil { -// return 0, err -// } -// tokenNum += imageTokenNum -// log.Printf("image token num: %d", imageTokenNum) -// } else if m.Type == dto.ContentTypeInputAudio { -// // TODO: 音频token数量计算 -// tokenNum += 100 -// } else if m.Type == dto.ContentTypeFile { -// tokenNum += 5000 -// } else if m.Type == dto.ContentTypeVideoUrl { -// tokenNum += 5000 -// } else { -// tokenNum += getTokenNum(tokenEncoder, m.Text) -// } -// } -// } -// } -// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> -// return tokenNum, nil -//} - func CountTokenInput(input any, model string) int { switch v := input.(type) { case string: diff --git a/service/user_notify.go b/service/user_notify.go index 7c864a1b1..c4a3ea91f 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -2,9 +2,12 @@ package service import ( "fmt" + "net/http" + "net/url" "one-api/common" "one-api/dto" "one-api/model" + "one-api/setting" "strings" ) @@ -51,6 +54,13 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data // 获取 webhook secret webhookSecret := userSetting.WebhookSecret return SendWebhookNotify(webhookURLStr, webhookSecret, data) + case dto.NotifyTypeBark: + barkURL := userSetting.BarkUrl + if barkURL == "" { + common.SysLog(fmt.Sprintf("user %d has no bark url, skip sending bark", userId)) + return nil + } + return sendBarkNotify(barkURL, data) } return nil } @@ -64,3 +74,67 @@ func sendEmailNotify(userEmail string, data dto.Notify) error { } return common.SendEmail(data.Title, userEmail, content) } + +func sendBarkNotify(barkURL string, data dto.Notify) error { + // 处理占位符 + content := data.Content + for _, value := range data.Values { + content = strings.Replace(content, dto.ContentValueParam, fmt.Sprintf("%v", value), 1) + } + + // 替换模板变量 + finalURL := strings.ReplaceAll(barkURL, "{{title}}", url.QueryEscape(data.Title)) + finalURL = strings.ReplaceAll(finalURL, "{{content}}", url.QueryEscape(content)) + + // 发送GET请求到Bark + var req *http.Request + var resp *http.Response + var err error + + if setting.EnableWorker() { + // 使用worker发送请求 + workerReq := &WorkerRequest{ + URL: finalURL, + Key: setting.WorkerValidKey, + Method: http.MethodGet, + Headers: map[string]string{ + "User-Agent": "OneAPI-Bark-Notify/1.0", + }, + } + + resp, err = DoWorkerRequest(workerReq) + if err != nil { + return fmt.Errorf("failed to send bark request through worker: %v", err) + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode) + } + } else { + // 直接发送请求 + req, err = http.NewRequest(http.MethodGet, finalURL, nil) + if err != nil { + return fmt.Errorf("failed to create bark request: %v", err) + } + + // 设置User-Agent + req.Header.Set("User-Agent", "OneAPI-Bark-Notify/1.0") + + // 发送请求 + client := GetHttpClient() + resp, err = client.Do(req) + if err != nil { + return fmt.Errorf("failed to send bark request: %v", err) + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("bark request failed with status code: %d", resp.StatusCode) + } + } + + return nil +} diff --git a/setting/console_setting/config.go b/setting/console_setting/config.go index 6327e5584..8cfcd0ed6 100644 --- a/setting/console_setting/config.go +++ b/setting/console_setting/config.go @@ -3,37 +3,37 @@ package console_setting import "one-api/setting/config" type ConsoleSetting struct { - ApiInfo string `json:"api_info"` // 控制台 API 信息 (JSON 数组字符串) - UptimeKumaGroups string `json:"uptime_kuma_groups"` // Uptime Kuma 分组配置 (JSON 数组字符串) - Announcements string `json:"announcements"` // 系统公告 (JSON 数组字符串) - FAQ string `json:"faq"` // 常见问题 (JSON 数组字符串) - ApiInfoEnabled bool `json:"api_info_enabled"` // 是否启用 API 信息面板 - UptimeKumaEnabled bool `json:"uptime_kuma_enabled"` // 是否启用 Uptime Kuma 面板 - AnnouncementsEnabled bool `json:"announcements_enabled"` // 是否启用系统公告面板 - FAQEnabled bool `json:"faq_enabled"` // 是否启用常见问答面板 + ApiInfo string `json:"api_info"` // 控制台 API 信息 (JSON 数组字符串) + UptimeKumaGroups string `json:"uptime_kuma_groups"` // Uptime Kuma 分组配置 (JSON 数组字符串) + Announcements string `json:"announcements"` // 系统公告 (JSON 数组字符串) + FAQ string `json:"faq"` // 常见问题 (JSON 数组字符串) + ApiInfoEnabled bool `json:"api_info_enabled"` // 是否启用 API 信息面板 + UptimeKumaEnabled bool `json:"uptime_kuma_enabled"` // 是否启用 Uptime Kuma 面板 + AnnouncementsEnabled bool `json:"announcements_enabled"` // 是否启用系统公告面板 + FAQEnabled bool `json:"faq_enabled"` // 是否启用常见问答面板 } // 默认配置 var defaultConsoleSetting = ConsoleSetting{ - ApiInfo: "", - UptimeKumaGroups: "", - Announcements: "", - FAQ: "", - ApiInfoEnabled: true, - UptimeKumaEnabled: true, - AnnouncementsEnabled: true, - FAQEnabled: true, + ApiInfo: "", + UptimeKumaGroups: "", + Announcements: "", + FAQ: "", + ApiInfoEnabled: true, + UptimeKumaEnabled: true, + AnnouncementsEnabled: true, + FAQEnabled: true, } // 全局实例 var consoleSetting = defaultConsoleSetting func init() { - // 注册到全局配置管理器,键名为 console_setting - config.GlobalConfig.Register("console_setting", &consoleSetting) + // 注册到全局配置管理器,键名为 console_setting + config.GlobalConfig.Register("console_setting", &consoleSetting) } // GetConsoleSetting 获取 ConsoleSetting 配置实例 func GetConsoleSetting() *ConsoleSetting { - return &consoleSetting -} \ No newline at end of file + return &consoleSetting +} diff --git a/setting/console_setting/validation.go b/setting/console_setting/validation.go index fda6453df..529457761 100644 --- a/setting/console_setting/validation.go +++ b/setting/console_setting/validation.go @@ -1,304 +1,304 @@ package console_setting import ( - "encoding/json" - "fmt" - "net/url" - "regexp" - "strings" - "time" - "sort" + "encoding/json" + "fmt" + "net/url" + "regexp" + "sort" + "strings" + "time" ) var ( - urlRegex = regexp.MustCompile(`^https?://(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?|(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))(?:\:[0-9]{1,5})?(?:/.*)?$`) - dangerousChars = []string{" 50 { - return fmt.Errorf("API信息数量不能超过50个") - } + if len(apiInfoList) > 50 { + return fmt.Errorf("API信息数量不能超过50个") + } - for i, apiInfo := range apiInfoList { - urlStr, ok := apiInfo["url"].(string) - if !ok || urlStr == "" { - return fmt.Errorf("第%d个API信息缺少URL字段", i+1) - } - route, ok := apiInfo["route"].(string) - if !ok || route == "" { - return fmt.Errorf("第%d个API信息缺少线路描述字段", i+1) - } - description, ok := apiInfo["description"].(string) - if !ok || description == "" { - return fmt.Errorf("第%d个API信息缺少说明字段", i+1) - } - color, ok := apiInfo["color"].(string) - if !ok || color == "" { - return fmt.Errorf("第%d个API信息缺少颜色字段", i+1) - } + for i, apiInfo := range apiInfoList { + urlStr, ok := apiInfo["url"].(string) + if !ok || urlStr == "" { + return fmt.Errorf("第%d个API信息缺少URL字段", i+1) + } + route, ok := apiInfo["route"].(string) + if !ok || route == "" { + return fmt.Errorf("第%d个API信息缺少线路描述字段", i+1) + } + description, ok := apiInfo["description"].(string) + if !ok || description == "" { + return fmt.Errorf("第%d个API信息缺少说明字段", i+1) + } + color, ok := apiInfo["color"].(string) + if !ok || color == "" { + return fmt.Errorf("第%d个API信息缺少颜色字段", i+1) + } - if err := validateURL(urlStr, i+1, "API信息"); err != nil { - return err - } + if err := validateURL(urlStr, i+1, "API信息"); err != nil { + return err + } - if len(urlStr) > 500 { - return fmt.Errorf("第%d个API信息的URL长度不能超过500字符", i+1) - } - if len(route) > 100 { - return fmt.Errorf("第%d个API信息的线路描述长度不能超过100字符", i+1) - } - if len(description) > 200 { - return fmt.Errorf("第%d个API信息的说明长度不能超过200字符", i+1) - } + if len(urlStr) > 500 { + return fmt.Errorf("第%d个API信息的URL长度不能超过500字符", i+1) + } + if len(route) > 100 { + return fmt.Errorf("第%d个API信息的线路描述长度不能超过100字符", i+1) + } + if len(description) > 200 { + return fmt.Errorf("第%d个API信息的说明长度不能超过200字符", i+1) + } - if !validColors[color] { - return fmt.Errorf("第%d个API信息的颜色值不合法", i+1) - } + if !validColors[color] { + return fmt.Errorf("第%d个API信息的颜色值不合法", i+1) + } - if err := checkDangerousContent(description, i+1, "API信息"); err != nil { - return err - } - if err := checkDangerousContent(route, i+1, "API信息"); err != nil { - return err - } - } - return nil + if err := checkDangerousContent(description, i+1, "API信息"); err != nil { + return err + } + if err := checkDangerousContent(route, i+1, "API信息"); err != nil { + return err + } + } + return nil } func GetApiInfo() []map[string]interface{} { - return getJSONList(GetConsoleSetting().ApiInfo) + return getJSONList(GetConsoleSetting().ApiInfo) } func validateAnnouncements(announcementsStr string) error { - list, err := parseJSONArray(announcementsStr, "系统公告") - if err != nil { - return err - } - if len(list) > 100 { - return fmt.Errorf("系统公告数量不能超过100个") - } - validTypes := map[string]bool{ - "default": true, "ongoing": true, "success": true, "warning": true, "error": true, - } - for i, ann := range list { - content, ok := ann["content"].(string) - if !ok || content == "" { - return fmt.Errorf("第%d个公告缺少内容字段", i+1) - } - publishDateAny, exists := ann["publishDate"] - if !exists { - return fmt.Errorf("第%d个公告缺少发布日期字段", i+1) - } - publishDateStr, ok := publishDateAny.(string) - if !ok || publishDateStr == "" { - return fmt.Errorf("第%d个公告的发布日期不能为空", i+1) - } - if _, err := time.Parse(time.RFC3339, publishDateStr); err != nil { - return fmt.Errorf("第%d个公告的发布日期格式错误", i+1) - } - if t, exists := ann["type"]; exists { - if typeStr, ok := t.(string); ok { - if !validTypes[typeStr] { - return fmt.Errorf("第%d个公告的类型值不合法", i+1) - } - } - } - if len(content) > 500 { - return fmt.Errorf("第%d个公告的内容长度不能超过500字符", i+1) - } - if extra, exists := ann["extra"]; exists { - if extraStr, ok := extra.(string); ok && len(extraStr) > 200 { - return fmt.Errorf("第%d个公告的说明长度不能超过200字符", i+1) - } - } - } - return nil + list, err := parseJSONArray(announcementsStr, "系统公告") + if err != nil { + return err + } + if len(list) > 100 { + return fmt.Errorf("系统公告数量不能超过100个") + } + validTypes := map[string]bool{ + "default": true, "ongoing": true, "success": true, "warning": true, "error": true, + } + for i, ann := range list { + content, ok := ann["content"].(string) + if !ok || content == "" { + return fmt.Errorf("第%d个公告缺少内容字段", i+1) + } + publishDateAny, exists := ann["publishDate"] + if !exists { + return fmt.Errorf("第%d个公告缺少发布日期字段", i+1) + } + publishDateStr, ok := publishDateAny.(string) + if !ok || publishDateStr == "" { + return fmt.Errorf("第%d个公告的发布日期不能为空", i+1) + } + if _, err := time.Parse(time.RFC3339, publishDateStr); err != nil { + return fmt.Errorf("第%d个公告的发布日期格式错误", i+1) + } + if t, exists := ann["type"]; exists { + if typeStr, ok := t.(string); ok { + if !validTypes[typeStr] { + return fmt.Errorf("第%d个公告的类型值不合法", i+1) + } + } + } + if len(content) > 500 { + return fmt.Errorf("第%d个公告的内容长度不能超过500字符", i+1) + } + if extra, exists := ann["extra"]; exists { + if extraStr, ok := extra.(string); ok && len(extraStr) > 200 { + return fmt.Errorf("第%d个公告的说明长度不能超过200字符", i+1) + } + } + } + return nil } func validateFAQ(faqStr string) error { - list, err := parseJSONArray(faqStr, "FAQ信息") - if err != nil { - return err - } - if len(list) > 100 { - return fmt.Errorf("FAQ数量不能超过100个") - } - for i, faq := range list { - question, ok := faq["question"].(string) - if !ok || question == "" { - return fmt.Errorf("第%d个FAQ缺少问题字段", i+1) - } - answer, ok := faq["answer"].(string) - if !ok || answer == "" { - return fmt.Errorf("第%d个FAQ缺少答案字段", i+1) - } - if len(question) > 200 { - return fmt.Errorf("第%d个FAQ的问题长度不能超过200字符", i+1) - } - if len(answer) > 1000 { - return fmt.Errorf("第%d个FAQ的答案长度不能超过1000字符", i+1) - } - } - return nil + list, err := parseJSONArray(faqStr, "FAQ信息") + if err != nil { + return err + } + if len(list) > 100 { + return fmt.Errorf("FAQ数量不能超过100个") + } + for i, faq := range list { + question, ok := faq["question"].(string) + if !ok || question == "" { + return fmt.Errorf("第%d个FAQ缺少问题字段", i+1) + } + answer, ok := faq["answer"].(string) + if !ok || answer == "" { + return fmt.Errorf("第%d个FAQ缺少答案字段", i+1) + } + if len(question) > 200 { + return fmt.Errorf("第%d个FAQ的问题长度不能超过200字符", i+1) + } + if len(answer) > 1000 { + return fmt.Errorf("第%d个FAQ的答案长度不能超过1000字符", i+1) + } + } + return nil } func getPublishTime(item map[string]interface{}) time.Time { - if v, ok := item["publishDate"]; ok { - if s, ok2 := v.(string); ok2 { - if t, err := time.Parse(time.RFC3339, s); err == nil { - return t - } - } - } - return time.Time{} + if v, ok := item["publishDate"]; ok { + if s, ok2 := v.(string); ok2 { + if t, err := time.Parse(time.RFC3339, s); err == nil { + return t + } + } + } + return time.Time{} } func GetAnnouncements() []map[string]interface{} { - list := getJSONList(GetConsoleSetting().Announcements) - sort.SliceStable(list, func(i, j int) bool { - return getPublishTime(list[i]).After(getPublishTime(list[j])) - }) - return list + list := getJSONList(GetConsoleSetting().Announcements) + sort.SliceStable(list, func(i, j int) bool { + return getPublishTime(list[i]).After(getPublishTime(list[j])) + }) + return list } func GetFAQ() []map[string]interface{} { - return getJSONList(GetConsoleSetting().FAQ) + return getJSONList(GetConsoleSetting().FAQ) } func validateUptimeKumaGroups(groupsStr string) error { - groups, err := parseJSONArray(groupsStr, "Uptime Kuma分组配置") - if err != nil { - return err - } + groups, err := parseJSONArray(groupsStr, "Uptime Kuma分组配置") + if err != nil { + return err + } - if len(groups) > 20 { - return fmt.Errorf("Uptime Kuma分组数量不能超过20个") - } + if len(groups) > 20 { + return fmt.Errorf("Uptime Kuma分组数量不能超过20个") + } - nameSet := make(map[string]bool) + nameSet := make(map[string]bool) - for i, group := range groups { - categoryName, ok := group["categoryName"].(string) - if !ok || categoryName == "" { - return fmt.Errorf("第%d个分组缺少分类名称字段", i+1) - } - if nameSet[categoryName] { - return fmt.Errorf("第%d个分组的分类名称与其他分组重复", i+1) - } - nameSet[categoryName] = true - urlStr, ok := group["url"].(string) - if !ok || urlStr == "" { - return fmt.Errorf("第%d个分组缺少URL字段", i+1) - } - slug, ok := group["slug"].(string) - if !ok || slug == "" { - return fmt.Errorf("第%d个分组缺少Slug字段", i+1) - } - description, ok := group["description"].(string) - if !ok { - description = "" - } + for i, group := range groups { + categoryName, ok := group["categoryName"].(string) + if !ok || categoryName == "" { + return fmt.Errorf("第%d个分组缺少分类名称字段", i+1) + } + if nameSet[categoryName] { + return fmt.Errorf("第%d个分组的分类名称与其他分组重复", i+1) + } + nameSet[categoryName] = true + urlStr, ok := group["url"].(string) + if !ok || urlStr == "" { + return fmt.Errorf("第%d个分组缺少URL字段", i+1) + } + slug, ok := group["slug"].(string) + if !ok || slug == "" { + return fmt.Errorf("第%d个分组缺少Slug字段", i+1) + } + description, ok := group["description"].(string) + if !ok { + description = "" + } - if err := validateURL(urlStr, i+1, "分组"); err != nil { - return err - } + if err := validateURL(urlStr, i+1, "分组"); err != nil { + return err + } - if len(categoryName) > 50 { - return fmt.Errorf("第%d个分组的分类名称长度不能超过50字符", i+1) - } - if len(urlStr) > 500 { - return fmt.Errorf("第%d个分组的URL长度不能超过500字符", i+1) - } - if len(slug) > 100 { - return fmt.Errorf("第%d个分组的Slug长度不能超过100字符", i+1) - } - if len(description) > 200 { - return fmt.Errorf("第%d个分组的描述长度不能超过200字符", i+1) - } + if len(categoryName) > 50 { + return fmt.Errorf("第%d个分组的分类名称长度不能超过50字符", i+1) + } + if len(urlStr) > 500 { + return fmt.Errorf("第%d个分组的URL长度不能超过500字符", i+1) + } + if len(slug) > 100 { + return fmt.Errorf("第%d个分组的Slug长度不能超过100字符", i+1) + } + if len(description) > 200 { + return fmt.Errorf("第%d个分组的描述长度不能超过200字符", i+1) + } - if !slugRegex.MatchString(slug) { - return fmt.Errorf("第%d个分组的Slug只能包含字母、数字、下划线和连字符", i+1) - } + if !slugRegex.MatchString(slug) { + return fmt.Errorf("第%d个分组的Slug只能包含字母、数字、下划线和连字符", i+1) + } - if err := checkDangerousContent(description, i+1, "分组"); err != nil { - return err - } - if err := checkDangerousContent(categoryName, i+1, "分组"); err != nil { - return err - } - } - return nil + if err := checkDangerousContent(description, i+1, "分组"); err != nil { + return err + } + if err := checkDangerousContent(categoryName, i+1, "分组"); err != nil { + return err + } + } + return nil } func GetUptimeKumaGroups() []map[string]interface{} { - return getJSONList(GetConsoleSetting().UptimeKumaGroups) -} \ No newline at end of file + return getJSONList(GetConsoleSetting().UptimeKumaGroups) +} diff --git a/setting/model_setting/gemini.go b/setting/model_setting/gemini.go index f132fec88..5412155f1 100644 --- a/setting/model_setting/gemini.go +++ b/setting/model_setting/gemini.go @@ -26,6 +26,7 @@ var defaultGeminiSettings = GeminiSettings{ SupportedImagineModels: []string{ "gemini-2.0-flash-exp-image-generation", "gemini-2.0-flash-exp", + "gemini-2.5-flash-image-preview", }, ThinkingAdapterEnabled: false, ThinkingAdapterBudgetTokensPercentage: 0.6, diff --git a/setting/operation_setting/monitor_setting.go b/setting/operation_setting/monitor_setting.go new file mode 100644 index 000000000..1d0bbec40 --- /dev/null +++ b/setting/operation_setting/monitor_setting.go @@ -0,0 +1,34 @@ +package operation_setting + +import ( + "one-api/setting/config" + "os" + "strconv" +) + +type MonitorSetting struct { + AutoTestChannelEnabled bool `json:"auto_test_channel_enabled"` + AutoTestChannelMinutes int `json:"auto_test_channel_minutes"` +} + +// 默认配置 +var monitorSetting = MonitorSetting{ + AutoTestChannelEnabled: false, + AutoTestChannelMinutes: 10, +} + +func init() { + // 注册到全局配置管理器 + config.GlobalConfig.Register("monitor_setting", &monitorSetting) +} + +func GetMonitorSetting() *MonitorSetting { + if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { + frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) + if err == nil && frequency > 0 { + monitorSetting.AutoTestChannelEnabled = true + monitorSetting.AutoTestChannelMinutes = frequency + } + } + return &monitorSetting +} diff --git a/setting/operation_setting/tools.go b/setting/operation_setting/tools.go index 549a1862e..b87265ee1 100644 --- a/setting/operation_setting/tools.go +++ b/setting/operation_setting/tools.go @@ -24,6 +24,10 @@ const ( ClaudeWebSearchPrice = 10.00 ) +const ( + Gemini25FlashImagePreviewImageOutputPrice = 30.00 +) + func GetClaudeWebSearchPricePerThousand() float64 { return ClaudeWebSearchPrice } @@ -65,3 +69,10 @@ func GetGeminiInputAudioPricePerMillionTokens(modelName string) float64 { } return 0 } + +func GetGeminiImageOutputPricePerMillionTokens(modelName string) float64 { + if strings.HasPrefix(modelName, "gemini-2.5-flash-image-preview") { + return Gemini25FlashImagePreviewImageOutputPrice + } + return 0 +} diff --git a/setting/ratio_setting/expose_ratio.go b/setting/ratio_setting/expose_ratio.go index 8fca0bcb0..783d9778e 100644 --- a/setting/ratio_setting/expose_ratio.go +++ b/setting/ratio_setting/expose_ratio.go @@ -5,13 +5,13 @@ import "sync/atomic" var exposeRatioEnabled atomic.Bool func init() { - exposeRatioEnabled.Store(false) + exposeRatioEnabled.Store(false) } func SetExposeRatioEnabled(enabled bool) { - exposeRatioEnabled.Store(enabled) + exposeRatioEnabled.Store(enabled) } func IsExposeRatioEnabled() bool { - return exposeRatioEnabled.Load() -} \ No newline at end of file + return exposeRatioEnabled.Load() +} diff --git a/setting/ratio_setting/exposed_cache.go b/setting/ratio_setting/exposed_cache.go index 9e5b6c300..2fe2cd09b 100644 --- a/setting/ratio_setting/exposed_cache.go +++ b/setting/ratio_setting/exposed_cache.go @@ -1,55 +1,55 @@ package ratio_setting import ( - "sync" - "sync/atomic" - "time" + "sync" + "sync/atomic" + "time" - "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin" ) const exposedDataTTL = 30 * time.Second type exposedCache struct { - data gin.H - expiresAt time.Time + data gin.H + expiresAt time.Time } var ( - exposedData atomic.Value - rebuildMu sync.Mutex + exposedData atomic.Value + rebuildMu sync.Mutex ) func InvalidateExposedDataCache() { - exposedData.Store((*exposedCache)(nil)) + exposedData.Store((*exposedCache)(nil)) } func cloneGinH(src gin.H) gin.H { - dst := make(gin.H, len(src)) - for k, v := range src { - dst[k] = v - } - return dst + dst := make(gin.H, len(src)) + for k, v := range src { + dst[k] = v + } + return dst } func GetExposedData() gin.H { - if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { - return cloneGinH(c.data) - } - rebuildMu.Lock() - defer rebuildMu.Unlock() - if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { - return cloneGinH(c.data) - } - newData := gin.H{ - "model_ratio": GetModelRatioCopy(), - "completion_ratio": GetCompletionRatioCopy(), - "cache_ratio": GetCacheRatioCopy(), - "model_price": GetModelPriceCopy(), - } - exposedData.Store(&exposedCache{ - data: newData, - expiresAt: time.Now().Add(exposedDataTTL), - }) - return cloneGinH(newData) -} \ No newline at end of file + if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { + return cloneGinH(c.data) + } + rebuildMu.Lock() + defer rebuildMu.Unlock() + if c, ok := exposedData.Load().(*exposedCache); ok && c != nil && time.Now().Before(c.expiresAt) { + return cloneGinH(c.data) + } + newData := gin.H{ + "model_ratio": GetModelRatioCopy(), + "completion_ratio": GetCompletionRatioCopy(), + "cache_ratio": GetCacheRatioCopy(), + "model_price": GetModelPriceCopy(), + } + exposedData.Store(&exposedCache{ + data: newData, + expiresAt: time.Now().Add(exposedDataTTL), + }) + return cloneGinH(newData) +} diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index d295b0b21..1a1b0afa8 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -178,6 +178,7 @@ var defaultModelRatio = map[string]float64{ "gemini-2.5-flash-lite-preview-thinking-*": 0.05, "gemini-2.5-flash-lite-preview-06-17": 0.05, "gemini-2.5-flash": 0.15, + "gemini-2.5-flash-image-preview": 0.15, // $0.30(text/image) / 1M tokens "text-embedding-004": 0.001, "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens @@ -293,10 +294,11 @@ var ( ) var defaultCompletionRatio = map[string]float64{ - "gpt-4-gizmo-*": 2, - "gpt-4o-gizmo-*": 3, - "gpt-4-all": 2, - "gpt-image-1": 8, + "gpt-4-gizmo-*": 2, + "gpt-4o-gizmo-*": 3, + "gpt-4-all": 2, + "gpt-image-1": 8, + "gemini-2.5-flash-image-preview": 8.3333333333, } // InitRatioSettings initializes all model related settings maps @@ -541,7 +543,7 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) { if strings.HasPrefix(name, "gemini-2.5-flash-lite") { return 4, false } - return 2.5 / 0.3, true + return 2.5 / 0.3, false } return 4, false } diff --git a/types/error.go b/types/error.go index 20d9c214e..883ee0641 100644 --- a/types/error.go +++ b/types/error.go @@ -145,13 +145,15 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError { Code: e.errorCode, } } + default: + result = OpenAIError{ + Message: e.Error(), + Type: string(e.errorType), + Param: "", + Code: e.errorCode, + } } - result = OpenAIError{ - Message: e.Error(), - Type: string(e.errorType), - Param: "", - Code: e.errorCode, - } + result.Message = common.MaskSensitiveInfo(result.Message) return result } @@ -160,13 +162,16 @@ func (e *NewAPIError) ToClaudeError() ClaudeError { var result ClaudeError switch e.errorType { case ErrorTypeOpenAIError: - openAIError := e.RelayError.(OpenAIError) - result = ClaudeError{ - Message: e.Error(), - Type: fmt.Sprintf("%v", openAIError.Code), + if openAIError, ok := e.RelayError.(OpenAIError); ok { + result = ClaudeError{ + Message: e.Error(), + Type: fmt.Sprintf("%v", openAIError.Code), + } } case ErrorTypeClaudeError: - result = e.RelayError.(ClaudeError) + if claudeError, ok := e.RelayError.(ClaudeError); ok { + result = claudeError + } default: result = ClaudeError{ Message: e.Error(), @@ -180,6 +185,14 @@ func (e *NewAPIError) ToClaudeError() ClaudeError { type NewAPIErrorOptions func(*NewAPIError) func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError { + var newErr *NewAPIError + // 保留深层传递的 new err + if errors.As(err, &newErr) { + for _, op := range ops { + op(newErr) + } + return newErr + } e := &NewAPIError{ Err: err, RelayError: nil, @@ -194,8 +207,21 @@ func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPI } func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { - if errorCode == ErrorCodeDoRequestFailed { - err = errors.New("upstream error: do request failed") + var newErr *NewAPIError + // 保留深层传递的 new err + if errors.As(err, &newErr) { + if newErr.RelayError == nil { + openaiError := OpenAIError{ + Message: newErr.Error(), + Type: string(errorCode), + Code: errorCode, + } + newErr.RelayError = openaiError + } + for _, op := range ops { + op(newErr) + } + return newErr } openaiError := OpenAIError{ Message: err.Error(), @@ -300,6 +326,15 @@ func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions { } } +func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions { + return func(e *NewAPIError) { + if common.DebugEnabled { + fmt.Printf("ErrOptionWithHideErrMsg: %s, origin error: %s", replaceStr, e.Err) + } + e.Err = errors.New(replaceStr) + } +} + func IsRecordErrorLog(e *NewAPIError) bool { if e == nil { return false diff --git a/web/.eslintrc.cjs b/web/.eslintrc.cjs index 5e88871d2..b1afd96f5 100644 --- a/web/.eslintrc.cjs +++ b/web/.eslintrc.cjs @@ -1,34 +1,42 @@ module.exports = { root: true, env: { browser: true, es2021: true, node: true }, - parserOptions: { ecmaVersion: 2020, sourceType: 'module', ecmaFeatures: { jsx: true } }, + parserOptions: { + ecmaVersion: 2020, + sourceType: 'module', + ecmaFeatures: { jsx: true }, + }, plugins: ['header', 'react-hooks'], overrides: [ { files: ['**/*.{js,jsx}'], rules: { - 'header/header': [2, 'block', [ - '', - 'Copyright (C) 2025 QuantumNous', - '', - 'This program is free software: you can redistribute it and/or modify', - 'it under the terms of the GNU Affero General Public License as', - 'published by the Free Software Foundation, either version 3 of the', - 'License, or (at your option) any later version.', - '', - 'This program is distributed in the hope that it will be useful,', - 'but WITHOUT ANY WARRANTY; without even the implied warranty of', - 'MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the', - 'GNU Affero General Public License for more details.', - '', - 'You should have received a copy of the GNU Affero General Public License', - 'along with this program. If not, see .', - '', - 'For commercial licensing, please contact support@quantumnous.com', - '' - ]], - 'no-multiple-empty-lines': ['error', { max: 1 }] - } - } - ] -}; \ No newline at end of file + 'header/header': [ + 2, + 'block', + [ + '', + 'Copyright (C) 2025 QuantumNous', + '', + 'This program is free software: you can redistribute it and/or modify', + 'it under the terms of the GNU Affero General Public License as', + 'published by the Free Software Foundation, either version 3 of the', + 'License, or (at your option) any later version.', + '', + 'This program is distributed in the hope that it will be useful,', + 'but WITHOUT ANY WARRANTY; without even the implied warranty of', + 'MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the', + 'GNU Affero General Public License for more details.', + '', + 'You should have received a copy of the GNU Affero General Public License', + 'along with this program. If not, see .', + '', + 'For commercial licensing, please contact support@quantumnous.com', + '', + ], + ], + 'no-multiple-empty-lines': ['error', { max: 1 }], + }, + }, + ], +}; diff --git a/web/index.html b/web/index.html index 8528f7fa7..09d87ae1a 100644 --- a/web/index.html +++ b/web/index.html @@ -1,19 +1,20 @@ + + + + + + + New API + - - - - - - - New API - - - - -
- - - - \ No newline at end of file + + +
+ + + diff --git a/web/postcss.config.js b/web/postcss.config.js index 590e21a49..5731ce76e 100644 --- a/web/postcss.config.js +++ b/web/postcss.config.js @@ -22,4 +22,4 @@ export default { tailwindcss: {}, autoprefixer: {}, }, -} +}; diff --git a/web/src/App.jsx b/web/src/App.jsx index fc623309c..635742f91 100644 --- a/web/src/App.jsx +++ b/web/src/App.jsx @@ -17,7 +17,7 @@ along with this program. If not, see . For commercial licensing, please contact support@quantumnous.com */ -import React, { lazy, Suspense } from 'react'; +import React, { lazy, Suspense, useContext, useMemo } from 'react'; import { Route, Routes, useLocation } from 'react-router-dom'; import Loading from './components/common/ui/Loading'; import User from './pages/User'; @@ -27,6 +27,7 @@ import LoginForm from './components/auth/LoginForm'; import NotFound from './pages/NotFound'; import Forbidden from './pages/Forbidden'; import Setting from './pages/Setting'; +import { StatusContext } from './context/Status'; import PasswordResetForm from './components/auth/PasswordResetForm'; import PasswordResetConfirm from './components/auth/PasswordResetConfirm'; @@ -53,6 +54,29 @@ const About = lazy(() => import('./pages/About')); function App() { const location = useLocation(); + const [statusState] = useContext(StatusContext); + + // 获取模型广场权限配置 + const pricingRequireAuth = useMemo(() => { + const headerNavModulesConfig = statusState?.status?.HeaderNavModules; + if (headerNavModulesConfig) { + try { + const modules = JSON.parse(headerNavModulesConfig); + + // 处理向后兼容性:如果pricing是boolean,默认不需要登录 + if (typeof modules.pricing === 'boolean') { + return false; // 默认不需要登录鉴权 + } + + // 如果是对象格式,使用requireAuth配置 + return modules.pricing?.requireAuth === true; + } catch (error) { + console.error('解析顶栏模块配置失败:', error); + return false; // 默认不需要登录 + } + } + return false; // 默认不需要登录 + }, [statusState?.status?.HeaderNavModules]); return ( @@ -73,10 +97,7 @@ function App() { } /> - } - /> + } /> } key={location.pathname}> - - + pricingRequireAuth ? ( + + } + key={location.pathname} + > + + + + ) : ( + } key={location.pathname}> + + + ) } /> { const [emailLoginLoading, setEmailLoginLoading] = useState(false); const [loginLoading, setLoginLoading] = useState(false); const [resetPasswordLoading, setResetPasswordLoading] = useState(false); - const [otherLoginOptionsLoading, setOtherLoginOptionsLoading] = useState(false); + const [otherLoginOptionsLoading, setOtherLoginOptionsLoading] = + useState(false); const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false); const [showTwoFA, setShowTwoFA] = useState(false); @@ -247,10 +241,7 @@ const LoginForm = () => { const handleOIDCClick = () => { setOidcLoading(true); try { - onOIDCClicked( - status.oidc_authorization_endpoint, - status.oidc_client_id - ); + onOIDCClicked(status.oidc_authorization_endpoint, status.oidc_client_id); } finally { // 由于重定向,这里不会执行到,但为了完整性添加 setTimeout(() => setOidcLoading(false), 3000); @@ -306,73 +297,87 @@ const LoginForm = () => { const renderOAuthOptions = () => { return ( -
-
-
- Logo - {systemName} +
+
+
+ Logo + + {systemName} +
- -
- {t('登 录')} + +
+ + {t('登 录')} +
-
-
+
+
{status.wechat_login && ( )} {status.github_oauth && ( )} {status.oidc_enabled && ( )} {status.linuxdo_oauth && ( )} {status.telegram_oauth && ( -
+
{
{!status.self_use_mode_enabled && ( -
+
{t('没有账户?')}{' '} {t('注册')} @@ -418,44 +423,46 @@ const LoginForm = () => { const renderEmailLoginForm = () => { return ( -
-
-
- Logo +
+
+
+ Logo {systemName}
- -
- {t('登 录')} + +
+ + {t('登 录')} +
-
-
+
+ handleChange('username', value)} prefix={} /> handleChange('password', value)} prefix={} /> -
+
- {(status.github_oauth || status.oidc_enabled || status.wechat_login || status.linuxdo_oauth || status.telegram_oauth) && ( + {(status.github_oauth || + status.oidc_enabled || + status.wechat_login || + status.linuxdo_oauth || + status.telegram_oauth) && ( <> {t('或')} -
+
-
- {t('返回登录')} +
+ + + {t('返回登录')} + +
diff --git a/web/src/components/auth/PasswordResetForm.jsx b/web/src/components/auth/PasswordResetForm.jsx index 93bedae20..92afc2afa 100644 --- a/web/src/components/auth/PasswordResetForm.jsx +++ b/web/src/components/auth/PasswordResetForm.jsx @@ -18,7 +18,14 @@ For commercial licensing, please contact support@quantumnous.com */ import React, { useEffect, useState } from 'react'; -import { API, getLogo, showError, showInfo, showSuccess, getSystemName } from '../../helpers'; +import { + API, + getLogo, + showError, + showInfo, + showSuccess, + getSystemName, +} from '../../helpers'; import Turnstile from 'react-turnstile'; import { Button, Card, Form, Typography } from '@douyinfe/semi-ui'; import { IconMail } from '@douyinfe/semi-icons'; @@ -97,57 +104,77 @@ const PasswordResetForm = () => { } return ( -
+
{/* 背景模糊晕染球 */} -
-
-
-
-
-
- Logo - {systemName} +
+
+
+
+
+
+ Logo + + {systemName} +
- -
- {t('密码重置')} + +
+ + {t('密码重置')} +
-
-
+
+ } /> -
+
-
- {t('想起来了?')} {t('登录')} +
+ + {t('想起来了?')}{' '} + + {t('登录')} + +
{turnstileEnabled && ( -
+
{ diff --git a/web/src/components/auth/RegisterForm.jsx b/web/src/components/auth/RegisterForm.jsx index 0b95d504f..9c98bdc3a 100644 --- a/web/src/components/auth/RegisterForm.jsx +++ b/web/src/components/auth/RegisterForm.jsx @@ -27,20 +27,19 @@ import { showSuccess, updateAPI, getSystemName, - setUserData + setUserData, } from '../../helpers'; import Turnstile from 'react-turnstile'; -import { - Button, - Card, - Divider, - Form, - Icon, - Modal, -} from '@douyinfe/semi-ui'; +import { Button, Card, Divider, Form, Icon, Modal } from '@douyinfe/semi-ui'; import Title from '@douyinfe/semi-ui/lib/es/typography/title'; import Text from '@douyinfe/semi-ui/lib/es/typography/text'; -import { IconGithubLogo, IconMail, IconUser, IconLock, IconKey } from '@douyinfe/semi-icons'; +import { + IconGithubLogo, + IconMail, + IconUser, + IconLock, + IconKey, +} from '@douyinfe/semi-icons'; import { onGitHubOAuthClicked, onLinuxDOOAuthClicked, @@ -78,7 +77,8 @@ const RegisterForm = () => { const [emailRegisterLoading, setEmailRegisterLoading] = useState(false); const [registerLoading, setRegisterLoading] = useState(false); const [verificationCodeLoading, setVerificationCodeLoading] = useState(false); - const [otherRegisterOptionsLoading, setOtherRegisterOptionsLoading] = useState(false); + const [otherRegisterOptionsLoading, setOtherRegisterOptionsLoading] = + useState(false); const [wechatCodeSubmitLoading, setWechatCodeSubmitLoading] = useState(false); const [disableButton, setDisableButton] = useState(false); const [countdown, setCountdown] = useState(30); @@ -236,10 +236,7 @@ const RegisterForm = () => { const handleOIDCClick = () => { setOidcLoading(true); try { - onOIDCClicked( - status.oidc_authorization_endpoint, - status.oidc_client_id - ); + onOIDCClicked(status.oidc_authorization_endpoint, status.oidc_client_id); } finally { setTimeout(() => setOidcLoading(false), 3000); } @@ -303,73 +300,87 @@ const RegisterForm = () => { const renderOAuthOptions = () => { return ( -
-
-
- Logo - {systemName} +
+
+
+ Logo + + {systemName} +
- -
- {t('注 册')} + +
+ + {t('注 册')} +
-
-
+
+
{status.wechat_login && ( )} {status.github_oauth && ( )} {status.oidc_enabled && ( )} {status.linuxdo_oauth && ( )} {status.telegram_oauth && ( -
+
{
-
- {t('已有账户?')} {t('登录')} +
+ + {t('已有账户?')}{' '} + + {t('登录')} + +
@@ -405,44 +424,48 @@ const RegisterForm = () => { const renderEmailRegisterForm = () => { return ( -
-
-
- Logo - {systemName} +
+
+
+ Logo + + {systemName} +
- -
- {t('注 册')} + +
+ + {t('注 册')} +
-
-
+
+ handleChange('username', value)} prefix={} /> handleChange('password', value)} prefix={} /> handleChange('password2', value)} prefix={} /> @@ -450,11 +473,11 @@ const RegisterForm = () => { {showEmailVerification && ( <> handleChange('email', value)} prefix={} suffix={ @@ -463,27 +486,31 @@ const RegisterForm = () => { loading={verificationCodeLoading} disabled={disableButton || verificationCodeLoading} > - {disableButton ? `${t('重新发送')} (${countdown})` : t('获取验证码')} + {disableButton + ? `${t('重新发送')} (${countdown})` + : t('获取验证码')} } /> handleChange('verification_code', value)} + name='verification_code' + onChange={(value) => + handleChange('verification_code', value) + } prefix={} /> )} -
+
- {(status.github_oauth || status.oidc_enabled || status.wechat_login || status.linuxdo_oauth || status.telegram_oauth) && ( + {(status.github_oauth || + status.oidc_enabled || + status.wechat_login || + status.linuxdo_oauth || + status.telegram_oauth) && ( <> {t('或')} -
+
-
- +
+ 提示:
• 验证码每30秒更新一次
• 如果无法获取验证码,请使用备用码 -
- • 每个备用码只能使用一次 +
• 每个备用码只能使用一次
@@ -145,39 +151,41 @@ const TwoFAVerification = ({ onSuccess, onBack, isModal = false }) => { } return ( -
+
两步验证 - + 请输入认证器应用显示的验证码完成登录
-
- +
+ 提示:
• 验证码每30秒更新一次
• 如果无法获取验证码,请使用备用码 -
- • 每个备用码只能使用一次 +
• 每个备用码只能使用一次
@@ -227,4 +241,4 @@ const TwoFAVerification = ({ onSuccess, onBack, isModal = false }) => { ); }; -export default TwoFAVerification; \ No newline at end of file +export default TwoFAVerification; diff --git a/web/src/components/common/markdown/MarkdownRenderer.jsx b/web/src/components/common/markdown/MarkdownRenderer.jsx index 820f2bbf6..f1283a640 100644 --- a/web/src/components/common/markdown/MarkdownRenderer.jsx +++ b/web/src/components/common/markdown/MarkdownRenderer.jsx @@ -160,7 +160,7 @@ export function PreCode(props) { }} >
@@ -367,7 +374,16 @@ function _MarkdownContent(props) { components={{ pre: PreCode, code: CustomCode, - p: (pProps) =>

, + p: (pProps) => ( +

+ ), a: (aProps) => { const href = aProps.href || ''; if (/\.(aac|mp3|opus|wav)$/.test(href)) { @@ -379,13 +395,16 @@ function _MarkdownContent(props) { } if (/\.(3gp|3g2|webm|ogv|mpeg|mp4|avi)$/.test(href)) { return ( -

, - h2: (props) =>

, - h3: (props) =>

, - h4: (props) =>

, - h5: (props) =>

, - h6: (props) =>
, + h1: (props) => ( +

+ ), + h2: (props) => ( +

+ ), + h3: (props) => ( +

+ ), + h4: (props) => ( +

+ ), + h5: (props) => ( +

+ ), + h6: (props) => ( +
+ ), blockquote: (props) => (
), - ul: (props) =>
    , - ol: (props) =>
      , - li: (props) =>
    1. , + ul: (props) => ( +
        + ), + ol: (props) => ( +
          + ), + li: (props) => ( +
        1. + ), table: (props) => (
          @@ -496,25 +614,29 @@ export function MarkdownRenderer(props) { color: 'var(--semi-color-text-0)', ...style, }} - dir="auto" + dir='auto' {...otherProps} > {loading ? ( -
          -
          +
          +
          正在渲染...
          ) : ( @@ -529,4 +651,4 @@ export function MarkdownRenderer(props) { ); } -export default MarkdownRenderer; \ No newline at end of file +export default MarkdownRenderer; diff --git a/web/src/components/common/markdown/markdown.css b/web/src/components/common/markdown/markdown.css index 3b5c1067d..e1e9e9cb4 100644 --- a/web/src/components/common/markdown/markdown.css +++ b/web/src/components/common/markdown/markdown.css @@ -59,12 +59,12 @@ } .user-message a { - color: #87CEEB !important; + color: #87ceeb !important; /* 浅蓝色链接 */ } .user-message a:hover { - color: #B0E0E6 !important; + color: #b0e0e6 !important; /* hover时更浅的蓝色 */ } @@ -298,7 +298,12 @@ pre:hover .copy-code-button { .markdown-body hr { border: none; height: 1px; - background: linear-gradient(to right, transparent, var(--semi-color-border), transparent); + background: linear-gradient( + to right, + transparent, + var(--semi-color-border), + transparent + ); margin: 24px 0; } @@ -332,7 +337,7 @@ pre:hover .copy-code-button { } /* 任务列表样式 */ -.markdown-body input[type="checkbox"] { +.markdown-body input[type='checkbox'] { margin-right: 8px; transform: scale(1.1); } @@ -441,4 +446,4 @@ pre:hover .copy-code-button { .animate-fade-in { animation: fade-in 0.6s cubic-bezier(0.22, 1, 0.36, 1) both; will-change: opacity, transform; -} \ No newline at end of file +} diff --git a/web/src/components/common/modals/TwoFactorAuthModal.jsx b/web/src/components/common/modals/TwoFactorAuthModal.jsx new file mode 100644 index 000000000..b0fc28e2a --- /dev/null +++ b/web/src/components/common/modals/TwoFactorAuthModal.jsx @@ -0,0 +1,146 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React from 'react'; +import { useTranslation } from 'react-i18next'; +import { Modal, Button, Input, Typography } from '@douyinfe/semi-ui'; + +/** + * 可复用的两步验证模态框组件 + * @param {Object} props + * @param {boolean} props.visible - 是否显示模态框 + * @param {string} props.code - 验证码值 + * @param {boolean} props.loading - 是否正在验证 + * @param {Function} props.onCodeChange - 验证码变化回调 + * @param {Function} props.onVerify - 验证回调 + * @param {Function} props.onCancel - 取消回调 + * @param {string} props.title - 模态框标题 + * @param {string} props.description - 验证描述文本 + * @param {string} props.placeholder - 输入框占位文本 + */ +const TwoFactorAuthModal = ({ + visible, + code, + loading, + onCodeChange, + onVerify, + onCancel, + title, + description, + placeholder, +}) => { + const { t } = useTranslation(); + + const handleKeyDown = (e) => { + if (e.key === 'Enter' && code && !loading) { + onVerify(); + } + }; + + return ( + +
          + + + +
          + {title || t('安全验证')} +
          + } + visible={visible} + onCancel={onCancel} + footer={ + <> + + + + } + width={500} + style={{ maxWidth: '90vw' }} + > +
          + {/* 安全提示 */} +
          +
          + + + +
          + + {t('安全验证')} + + + {description || t('为了保护账户安全,请验证您的两步验证码。')} + +
          +
          +
          + + {/* 验证码输入 */} +
          + + {t('验证身份')} + + + + {t('支持6位TOTP验证码或8位备用码')} + +
          +
          + + ); +}; + +export default TwoFactorAuthModal; diff --git a/web/src/components/common/ui/CardPro.jsx b/web/src/components/common/ui/CardPro.jsx index 3e1247229..2c95f97c7 100644 --- a/web/src/components/common/ui/CardPro.jsx +++ b/web/src/components/common/ui/CardPro.jsx @@ -27,15 +27,15 @@ const { Text } = Typography; /** * CardPro 高级卡片组件 - * + * * 布局分为6个区域: * 1. 统计信息区域 (statsArea) - * 2. 描述信息区域 (descriptionArea) + * 2. 描述信息区域 (descriptionArea) * 3. 类型切换/标签区域 (tabsArea) * 4. 操作按钮区域 (actionsArea) * 5. 搜索表单区域 (searchArea) * 6. 分页区域 (paginationArea) - 固定在卡片底部 - * + * * 支持三种布局类型: * - type1: 操作型 (如TokensTable) - 描述信息 + 操作按钮 + 搜索表单 * - type2: 查询型 (如LogsTable) - 统计信息 + 搜索表单 @@ -71,47 +71,38 @@ const CardPro = ({ const hasMobileHideableContent = actionsArea || searchArea; const renderHeader = () => { - const hasContent = statsArea || descriptionArea || tabsArea || actionsArea || searchArea; + const hasContent = + statsArea || descriptionArea || tabsArea || actionsArea || searchArea; if (!hasContent) return null; return ( -
          +
          {/* 统计信息区域 - 用于type2 */} - {type === 'type2' && statsArea && ( - <> - {statsArea} - - )} + {type === 'type2' && statsArea && <>{statsArea}} {/* 描述信息区域 - 用于type1和type3 */} {(type === 'type1' || type === 'type3') && descriptionArea && ( - <> - {descriptionArea} - + <>{descriptionArea} )} {/* 第一个分隔线 - 在描述信息或统计信息后面 */} {((type === 'type1' || type === 'type3') && descriptionArea) || - (type === 'type2' && statsArea) ? ( - + (type === 'type2' && statsArea) ? ( + ) : null} {/* 类型切换/标签区域 - 主要用于type3 */} - {type === 'type3' && tabsArea && ( - <> - {tabsArea} - - )} + {type === 'type3' && tabsArea && <>{tabsArea}} {/* 移动端操作切换按钮 */} {isMobile && hasMobileHideableContent && ( <> -
          +
          ); @@ -214,4 +197,4 @@ CardPro.propTypes = { t: PropTypes.func, }; -export default CardPro; \ No newline at end of file +export default CardPro; diff --git a/web/src/components/common/ui/CardTable.jsx b/web/src/components/common/ui/CardTable.jsx index f7f443dbd..8a331d07e 100644 --- a/web/src/components/common/ui/CardTable.jsx +++ b/web/src/components/common/ui/CardTable.jsx @@ -19,7 +19,15 @@ For commercial licensing, please contact support@quantumnous.com import React, { useState, useEffect, useRef } from 'react'; import { useTranslation } from 'react-i18next'; -import { Table, Card, Skeleton, Pagination, Empty, Button, Collapsible } from '@douyinfe/semi-ui'; +import { + Table, + Card, + Skeleton, + Pagination, + Empty, + Button, + Collapsible, +} from '@douyinfe/semi-ui'; import { IconChevronDown, IconChevronUp } from '@douyinfe/semi-icons'; import PropTypes from 'prop-types'; import { useIsMobile } from '../../../hooks/common/useIsMobile'; @@ -27,7 +35,7 @@ import { useMinimumLoadingTime } from '../../../hooks/common/useMinimumLoadingTi /** * CardTable 响应式表格组件 - * + * * 在桌面端渲染 Semi-UI 的 Table 组件,在移动端则将每一行数据渲染成 Card 形式。 * 该组件与 Table 组件的大部分 API 保持一致,只需将原 Table 换成 CardTable 即可。 */ @@ -75,18 +83,22 @@ const CardTable = ({ const renderSkeletonCard = (key) => { const placeholder = ( -
          +
          {visibleCols.map((col, idx) => { if (!col.title) { return ( -
          +
          ); } return ( -
          +
          + ); }; return ( -
          +
          {[1, 2, 3].map((i) => renderSkeletonCard(i))}
          ); @@ -127,9 +139,12 @@ const CardTable = ({ (!tableProps.rowExpandable || tableProps.rowExpandable(record)); return ( - + {columns.map((col, colIdx) => { - if (tableProps?.visibleColumns && !tableProps.visibleColumns[col.key]) { + if ( + tableProps?.visibleColumns && + !tableProps.visibleColumns[col.key] + ) { return null; } @@ -140,7 +155,7 @@ const CardTable = ({ if (!title) { return ( -
          +
          {cellContent}
          ); @@ -149,14 +164,16 @@ const CardTable = ({ return (
          - + {title} -
          - {cellContent !== undefined && cellContent !== null ? cellContent : '-'} +
          + {cellContent !== undefined && cellContent !== null + ? cellContent + : '-'}
          ); @@ -177,7 +194,7 @@ const CardTable = ({ {showDetails ? t('收起') : t('详情')} -
          +
          {tableProps.expandedRowRender(record, index)}
          @@ -190,19 +207,23 @@ const CardTable = ({ if (isEmpty) { if (tableProps.empty) return tableProps.empty; return ( -
          - +
          +
          ); } return ( -
          +
          {dataSource.map((record, index) => ( - + ))} {!hidePagination && tableProps.pagination && dataSource.length > 0 && ( -
          +
          )} @@ -218,4 +239,4 @@ CardTable.propTypes = { hidePagination: PropTypes.bool, }; -export default CardTable; \ No newline at end of file +export default CardTable; diff --git a/web/src/components/common/ui/ChannelKeyDisplay.jsx b/web/src/components/common/ui/ChannelKeyDisplay.jsx new file mode 100644 index 000000000..79aa3eec7 --- /dev/null +++ b/web/src/components/common/ui/ChannelKeyDisplay.jsx @@ -0,0 +1,280 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React from 'react'; +import { useTranslation } from 'react-i18next'; +import { Card, Button, Typography, Tag } from '@douyinfe/semi-ui'; +import { copy, showSuccess } from '../../../helpers'; + +/** + * 解析密钥数据,支持多种格式 + * @param {string} keyData - 密钥数据 + * @param {Function} t - 翻译函数 + * @returns {Array} 解析后的密钥数组 + */ +const parseChannelKeys = (keyData, t) => { + if (!keyData) return []; + + const trimmed = keyData.trim(); + + // 检查是否是JSON数组格式(如Vertex AI) + if (trimmed.startsWith('[')) { + try { + const parsed = JSON.parse(trimmed); + if (Array.isArray(parsed)) { + return parsed.map((item, index) => ({ + id: index, + content: + typeof item === 'string' ? item : JSON.stringify(item, null, 2), + type: typeof item === 'string' ? 'text' : 'json', + label: `${t('密钥')} ${index + 1}`, + })); + } + } catch (e) { + // 如果解析失败,按普通文本处理 + console.warn('Failed to parse JSON keys:', e); + } + } + + // 检查是否是多行密钥(按换行符分割) + const lines = trimmed.split('\n').filter((line) => line.trim()); + if (lines.length > 1) { + return lines.map((line, index) => ({ + id: index, + content: line.trim(), + type: 'text', + label: `${t('密钥')} ${index + 1}`, + })); + } + + // 单个密钥 + return [ + { + id: 0, + content: trimmed, + type: trimmed.startsWith('{') ? 'json' : 'text', + label: t('密钥'), + }, + ]; +}; + +/** + * 可复用的密钥显示组件 + * @param {Object} props + * @param {string} props.keyData - 密钥数据 + * @param {boolean} props.showSuccessIcon - 是否显示成功图标 + * @param {string} props.successText - 成功文本 + * @param {boolean} props.showWarning - 是否显示安全警告 + * @param {string} props.warningText - 警告文本 + */ +const ChannelKeyDisplay = ({ + keyData, + showSuccessIcon = true, + successText, + showWarning = true, + warningText, +}) => { + const { t } = useTranslation(); + + const parsedKeys = parseChannelKeys(keyData, t); + const isMultipleKeys = parsedKeys.length > 1; + + const handleCopyAll = () => { + copy(keyData); + showSuccess(t('所有密钥已复制到剪贴板')); + }; + + const handleCopyKey = (content) => { + copy(content); + showSuccess(t('密钥已复制到剪贴板')); + }; + + return ( +
          + {/* 成功状态 */} + {showSuccessIcon && ( +
          + + + + + {successText || t('验证成功')} + +
          + )} + + {/* 密钥内容 */} +
          +
          + + {isMultipleKeys ? t('渠道密钥列表') : t('渠道密钥')} + + {isMultipleKeys && ( +
          + + {t('共 {{count}} 个密钥', { count: parsedKeys.length })} + + +
          + )} +
          + +
          + {parsedKeys.map((keyItem) => ( + +
          +
          + + {keyItem.label} + +
          + {keyItem.type === 'json' && ( + + {t('JSON')} + + )} + +
          +
          + +
          + + {keyItem.content} + +
          + + {keyItem.type === 'json' && ( + + {t('JSON格式密钥,请确保格式正确')} + + )} +
          +
          + ))} +
          + + {isMultipleKeys && ( +
          + + + + + {t( + '检测到多个密钥,您可以单独复制每个密钥,或点击复制全部获取完整内容。', + )} + +
          + )} +
          + + {/* 安全警告 */} + {showWarning && ( +
          +
          + + + +
          + + {t('安全提醒')} + + + {warningText || + t( + '请妥善保管密钥信息,不要泄露给他人。如有安全疑虑,请及时更换密钥。', + )} + +
          +
          +
          + )} +
          + ); +}; + +export default ChannelKeyDisplay; diff --git a/web/src/components/common/ui/CompactModeToggle.jsx b/web/src/components/common/ui/CompactModeToggle.jsx index 631156ee1..40da0abc0 100644 --- a/web/src/components/common/ui/CompactModeToggle.jsx +++ b/web/src/components/common/ui/CompactModeToggle.jsx @@ -65,4 +65,4 @@ CompactModeToggle.propTypes = { className: PropTypes.string, }; -export default CompactModeToggle; \ No newline at end of file +export default CompactModeToggle; diff --git a/web/src/components/common/ui/JSONEditor.jsx b/web/src/components/common/ui/JSONEditor.jsx index 4acbe270f..d89753872 100644 --- a/web/src/components/common/ui/JSONEditor.jsx +++ b/web/src/components/common/ui/JSONEditor.jsx @@ -36,11 +36,7 @@ import { Divider, Tooltip, } from '@douyinfe/semi-ui'; -import { - IconPlus, - IconDelete, - IconAlertTriangle, -} from '@douyinfe/semi-icons'; +import { IconPlus, IconDelete, IconAlertTriangle } from '@douyinfe/semi-icons'; const { Text } = Typography; @@ -88,7 +84,7 @@ const JSONEditor = ({ // 将键值对数组转换为对象(重复键时后面的会覆盖前面的) const keyValueArrayToObject = useCallback((arr) => { const result = {}; - arr.forEach(item => { + arr.forEach((item) => { if (item.key) { result[item.key] = item.value; } @@ -115,7 +111,8 @@ const JSONEditor = ({ // 手动模式下的本地文本缓冲 const [manualText, setManualText] = useState(() => { if (typeof value === 'string') return value; - if (value && typeof value === 'object') return JSON.stringify(value, null, 2); + if (value && typeof value === 'object') + return JSON.stringify(value, null, 2); return ''; }); @@ -140,7 +137,7 @@ const JSONEditor = ({ const keyCount = {}; const duplicates = new Set(); - keyValuePairs.forEach(pair => { + keyValuePairs.forEach((pair) => { if (pair.key) { keyCount[pair.key] = (keyCount[pair.key] || 0) + 1; if (keyCount[pair.key] > 1) { @@ -178,51 +175,65 @@ const JSONEditor = ({ useEffect(() => { if (editMode !== 'manual') { if (typeof value === 'string') setManualText(value); - else if (value && typeof value === 'object') setManualText(JSON.stringify(value, null, 2)); + else if (value && typeof value === 'object') + setManualText(JSON.stringify(value, null, 2)); else setManualText(''); } }, [value, editMode]); // 处理可视化编辑的数据变化 - const handleVisualChange = useCallback((newPairs) => { - setKeyValuePairs(newPairs); - const jsonObject = keyValueArrayToObject(newPairs); - const jsonString = Object.keys(jsonObject).length === 0 ? '' : JSON.stringify(jsonObject, null, 2); + const handleVisualChange = useCallback( + (newPairs) => { + setKeyValuePairs(newPairs); + const jsonObject = keyValueArrayToObject(newPairs); + const jsonString = + Object.keys(jsonObject).length === 0 + ? '' + : JSON.stringify(jsonObject, null, 2); - setJsonError(''); + setJsonError(''); - // 通过formApi设置值 - if (formApi && field) { - formApi.setValue(field, jsonString); - } + // 通过formApi设置值 + if (formApi && field) { + formApi.setValue(field, jsonString); + } - onChange?.(jsonString); - }, [onChange, formApi, field, keyValueArrayToObject]); + onChange?.(jsonString); + }, + [onChange, formApi, field, keyValueArrayToObject], + ); // 处理手动编辑的数据变化 - const handleManualChange = useCallback((newValue) => { - setManualText(newValue); - if (newValue && newValue.trim()) { - try { - const parsed = JSON.parse(newValue); - setKeyValuePairs(objectToKeyValueArray(parsed, keyValuePairs)); + const handleManualChange = useCallback( + (newValue) => { + setManualText(newValue); + if (newValue && newValue.trim()) { + try { + const parsed = JSON.parse(newValue); + setKeyValuePairs(objectToKeyValueArray(parsed, keyValuePairs)); + setJsonError(''); + onChange?.(newValue); + } catch (error) { + setJsonError(error.message); + } + } else { + setKeyValuePairs([]); setJsonError(''); - onChange?.(newValue); - } catch (error) { - setJsonError(error.message); + onChange?.(''); } - } else { - setKeyValuePairs([]); - setJsonError(''); - onChange?.(''); - } - }, [onChange, objectToKeyValueArray, keyValuePairs]); + }, + [onChange, objectToKeyValueArray, keyValuePairs], + ); // 切换编辑模式 const toggleEditMode = useCallback(() => { if (editMode === 'visual') { const jsonObject = keyValueArrayToObject(keyValuePairs); - setManualText(Object.keys(jsonObject).length === 0 ? '' : JSON.stringify(jsonObject, null, 2)); + setManualText( + Object.keys(jsonObject).length === 0 + ? '' + : JSON.stringify(jsonObject, null, 2), + ); setEditMode('manual'); } else { try { @@ -242,12 +253,19 @@ const JSONEditor = ({ return; } } - }, [editMode, value, manualText, keyValuePairs, keyValueArrayToObject, objectToKeyValueArray]); + }, [ + editMode, + value, + manualText, + keyValuePairs, + keyValueArrayToObject, + objectToKeyValueArray, + ]); // 添加键值对 const addKeyValue = useCallback(() => { const newPairs = [...keyValuePairs]; - const existingKeys = newPairs.map(p => p.key); + const existingKeys = newPairs.map((p) => p.key); let counter = 1; let newKey = `field_${counter}`; while (existingKeys.includes(newKey)) { @@ -257,32 +275,41 @@ const JSONEditor = ({ newPairs.push({ id: generateUniqueId(), key: newKey, - value: '' + value: '', }); handleVisualChange(newPairs); }, [keyValuePairs, handleVisualChange]); // 删除键值对 - const removeKeyValue = useCallback((id) => { - const newPairs = keyValuePairs.filter(pair => pair.id !== id); - handleVisualChange(newPairs); - }, [keyValuePairs, handleVisualChange]); + const removeKeyValue = useCallback( + (id) => { + const newPairs = keyValuePairs.filter((pair) => pair.id !== id); + handleVisualChange(newPairs); + }, + [keyValuePairs, handleVisualChange], + ); // 更新键名 - const updateKey = useCallback((id, newKey) => { - const newPairs = keyValuePairs.map(pair => - pair.id === id ? { ...pair, key: newKey } : pair - ); - handleVisualChange(newPairs); - }, [keyValuePairs, handleVisualChange]); + const updateKey = useCallback( + (id, newKey) => { + const newPairs = keyValuePairs.map((pair) => + pair.id === id ? { ...pair, key: newKey } : pair, + ); + handleVisualChange(newPairs); + }, + [keyValuePairs, handleVisualChange], + ); // 更新值 - const updateValue = useCallback((id, newValue) => { - const newPairs = keyValuePairs.map(pair => - pair.id === id ? { ...pair, value: newValue } : pair - ); - handleVisualChange(newPairs); - }, [keyValuePairs, handleVisualChange]); + const updateValue = useCallback( + (id, newValue) => { + const newPairs = keyValuePairs.map((pair) => + pair.id === id ? { ...pair, value: newValue } : pair, + ); + handleVisualChange(newPairs); + }, + [keyValuePairs, handleVisualChange], + ); // 填入模板 const fillTemplate = useCallback(() => { @@ -298,7 +325,14 @@ const JSONEditor = ({ onChange?.(templateString); setJsonError(''); } - }, [template, onChange, formApi, field, objectToKeyValueArray, keyValuePairs]); + }, [ + template, + onChange, + formApi, + field, + objectToKeyValueArray, + keyValuePairs, + ]); // 渲染值输入控件(支持嵌套) const renderValueInput = (pairId, value) => { @@ -306,12 +340,12 @@ const JSONEditor = ({ if (valueType === 'boolean') { return ( -
          +
          updateValue(pairId, newValue)} /> - + {value ? t('true') : t('false')}
          @@ -373,29 +407,29 @@ const JSONEditor = ({ // 渲染键值对编辑器 const renderKeyValueEditor = () => { return ( -
          +
          {/* 重复键警告 */} {duplicateKeys.size > 0 && ( } description={
          {t('存在重复的键名:')} {Array.from(duplicateKeys).join(', ')}
          - + {t('注意:JSON中重复的键只会保留最后一个同名键的值')}
          } - className="mb-3" + className='mb-3' /> )} {keyValuePairs.length === 0 && ( -
          - +
          + {t('暂无数据,点击下方按钮添加键值对')}
          @@ -403,13 +437,14 @@ const JSONEditor = ({ {keyValuePairs.map((pair, index) => { const isDuplicate = duplicateKeys.has(pair.key); - const isLastDuplicate = isDuplicate && - keyValuePairs.slice(index + 1).every(p => p.key !== pair.key); + const isLastDuplicate = + isDuplicate && + keyValuePairs.slice(index + 1).every((p) => p.key !== pair.key); return ( - -
          -
          + +
          +
          )}
          - - {renderValueInput(pair.id, pair.value)} - + {renderValueInput(pair.id, pair.value)}-
          +
          @@ -546,8 +582,8 @@ const JSONEditor = ({
          @@ -590,9 +626,9 @@ const JSONEditor = ({ +
          { if (key === 'manual' && editMode === 'visual') { @@ -602,16 +638,12 @@ const JSONEditor = ({ } }} > - - + + {template && templateLabel && ( - )} @@ -619,14 +651,14 @@ const JSONEditor = ({ } headerStyle={{ padding: '12px 16px' }} bodyStyle={{ padding: '16px' }} - className="!rounded-2xl" + className='!rounded-2xl' > {/* JSON错误提示 */} {hasJsonError && ( )} @@ -668,17 +700,15 @@ const JSONEditor = ({ {/* 额外文本显示在卡片底部 */} {extraText && ( - {extraText} + + {extraText} + )} - {extraFooter && ( -
          - {extraFooter} -
          - )} + {extraFooter &&
          {extraFooter}
          } ); }; -export default JSONEditor; \ No newline at end of file +export default JSONEditor; diff --git a/web/src/components/common/ui/Loading.jsx b/web/src/components/common/ui/Loading.jsx index 60f947486..a2fc6f8e9 100644 --- a/web/src/components/common/ui/Loading.jsx +++ b/web/src/components/common/ui/Loading.jsx @@ -21,13 +21,9 @@ import React from 'react'; import { Spin } from '@douyinfe/semi-ui'; const Loading = ({ size = 'small' }) => { - return ( -
          - +
          +
          ); }; diff --git a/web/src/components/common/ui/RenderUtils.jsx b/web/src/components/common/ui/RenderUtils.jsx index 26a72e16f..3411649ce 100644 --- a/web/src/components/common/ui/RenderUtils.jsx +++ b/web/src/components/common/ui/RenderUtils.jsx @@ -57,4 +57,4 @@ export const renderDescription = (text, maxWidth = 200) => { {text || '-'} ); -}; \ No newline at end of file +}; diff --git a/web/src/components/common/ui/ScrollableContainer.jsx b/web/src/components/common/ui/ScrollableContainer.jsx index 0137c64b8..441c8c030 100644 --- a/web/src/components/common/ui/ScrollableContainer.jsx +++ b/web/src/components/common/ui/ScrollableContainer.jsx @@ -24,197 +24,219 @@ import React, { useCallback, useMemo, useImperativeHandle, - forwardRef + forwardRef, } from 'react'; /** * ScrollableContainer 可滚动容器组件 - * + * * 提供自动检测滚动状态和显示渐变指示器的功能 * 当内容超出容器高度且未滚动到底部时,会显示底部渐变指示器 - * + * */ -const ScrollableContainer = forwardRef(({ - children, - maxHeight = '24rem', - className = '', - contentClassName = 'p-2', - fadeIndicatorClassName = '', - checkInterval = 100, - scrollThreshold = 5, - debounceDelay = 16, // ~60fps - onScroll, - onScrollStateChange, - ...props -}, ref) => { - const scrollRef = useRef(null); - const containerRef = useRef(null); - const debounceTimerRef = useRef(null); - const resizeObserverRef = useRef(null); - const onScrollStateChangeRef = useRef(onScrollStateChange); - const onScrollRef = useRef(onScroll); - - const [showScrollHint, setShowScrollHint] = useState(false); - - useEffect(() => { - onScrollStateChangeRef.current = onScrollStateChange; - }, [onScrollStateChange]); - - useEffect(() => { - onScrollRef.current = onScroll; - }, [onScroll]); - - const debounce = useCallback((func, delay) => { - return (...args) => { - if (debounceTimerRef.current) { - clearTimeout(debounceTimerRef.current); - } - debounceTimerRef.current = setTimeout(() => func(...args), delay); - }; - }, []); - - const checkScrollable = useCallback(() => { - if (!scrollRef.current) return; - - const element = scrollRef.current; - const isScrollable = element.scrollHeight > element.clientHeight; - const isAtBottom = element.scrollTop + element.clientHeight >= element.scrollHeight - scrollThreshold; - const shouldShowHint = isScrollable && !isAtBottom; - - setShowScrollHint(shouldShowHint); - - if (onScrollStateChangeRef.current) { - onScrollStateChangeRef.current({ - isScrollable, - isAtBottom, - showScrollHint: shouldShowHint, - scrollTop: element.scrollTop, - scrollHeight: element.scrollHeight, - clientHeight: element.clientHeight - }); - } - }, [scrollThreshold]); - - const debouncedCheckScrollable = useMemo(() => - debounce(checkScrollable, debounceDelay), - [debounce, checkScrollable, debounceDelay] - ); - - const handleScroll = useCallback((e) => { - debouncedCheckScrollable(); - if (onScrollRef.current) { - onScrollRef.current(e); - } - }, [debouncedCheckScrollable]); - - useImperativeHandle(ref, () => ({ - checkScrollable: () => { - checkScrollable(); +const ScrollableContainer = forwardRef( + ( + { + children, + maxHeight = '24rem', + className = '', + contentClassName = '', + fadeIndicatorClassName = '', + checkInterval = 100, + scrollThreshold = 5, + debounceDelay = 16, // ~60fps + onScroll, + onScrollStateChange, + ...props }, - scrollToTop: () => { - if (scrollRef.current) { - scrollRef.current.scrollTop = 0; - } - }, - scrollToBottom: () => { - if (scrollRef.current) { - scrollRef.current.scrollTop = scrollRef.current.scrollHeight; - } - }, - getScrollInfo: () => { - if (!scrollRef.current) return null; - const element = scrollRef.current; - return { - scrollTop: element.scrollTop, - scrollHeight: element.scrollHeight, - clientHeight: element.clientHeight, - isScrollable: element.scrollHeight > element.clientHeight, - isAtBottom: element.scrollTop + element.clientHeight >= element.scrollHeight - scrollThreshold + ref, + ) => { + const scrollRef = useRef(null); + const containerRef = useRef(null); + const debounceTimerRef = useRef(null); + const resizeObserverRef = useRef(null); + const onScrollStateChangeRef = useRef(onScrollStateChange); + const onScrollRef = useRef(onScroll); + + const [showScrollHint, setShowScrollHint] = useState(false); + + useEffect(() => { + onScrollStateChangeRef.current = onScrollStateChange; + }, [onScrollStateChange]); + + useEffect(() => { + onScrollRef.current = onScroll; + }, [onScroll]); + + const debounce = useCallback((func, delay) => { + return (...args) => { + if (debounceTimerRef.current) { + clearTimeout(debounceTimerRef.current); + } + debounceTimerRef.current = setTimeout(() => func(...args), delay); }; - } - }), [checkScrollable, scrollThreshold]); + }, []); - useEffect(() => { - const timer = setTimeout(() => { - checkScrollable(); - }, checkInterval); - return () => clearTimeout(timer); - }, [checkScrollable, checkInterval]); + const checkScrollable = useCallback(() => { + if (!scrollRef.current) return; - useEffect(() => { - if (!scrollRef.current) return; + const element = scrollRef.current; + const isScrollable = element.scrollHeight > element.clientHeight; + const isAtBottom = + element.scrollTop + element.clientHeight >= + element.scrollHeight - scrollThreshold; + const shouldShowHint = isScrollable && !isAtBottom; - if (typeof ResizeObserver === 'undefined') { - if (typeof MutationObserver !== 'undefined') { - const observer = new MutationObserver(() => { - debouncedCheckScrollable(); + setShowScrollHint(shouldShowHint); + + if (onScrollStateChangeRef.current) { + onScrollStateChangeRef.current({ + isScrollable, + isAtBottom, + showScrollHint: shouldShowHint, + scrollTop: element.scrollTop, + scrollHeight: element.scrollHeight, + clientHeight: element.clientHeight, }); - - observer.observe(scrollRef.current, { - childList: true, - subtree: true, - attributes: true, - characterData: true - }); - - return () => observer.disconnect(); } - return; - } + }, [scrollThreshold]); - resizeObserverRef.current = new ResizeObserver((entries) => { - for (const entry of entries) { + const debouncedCheckScrollable = useMemo( + () => debounce(checkScrollable, debounceDelay), + [debounce, checkScrollable, debounceDelay], + ); + + const handleScroll = useCallback( + (e) => { debouncedCheckScrollable(); + if (onScrollRef.current) { + onScrollRef.current(e); + } + }, + [debouncedCheckScrollable], + ); + + useImperativeHandle( + ref, + () => ({ + checkScrollable: () => { + checkScrollable(); + }, + scrollToTop: () => { + if (scrollRef.current) { + scrollRef.current.scrollTop = 0; + } + }, + scrollToBottom: () => { + if (scrollRef.current) { + scrollRef.current.scrollTop = scrollRef.current.scrollHeight; + } + }, + getScrollInfo: () => { + if (!scrollRef.current) return null; + const element = scrollRef.current; + return { + scrollTop: element.scrollTop, + scrollHeight: element.scrollHeight, + clientHeight: element.clientHeight, + isScrollable: element.scrollHeight > element.clientHeight, + isAtBottom: + element.scrollTop + element.clientHeight >= + element.scrollHeight - scrollThreshold, + }; + }, + }), + [checkScrollable, scrollThreshold], + ); + + useEffect(() => { + const timer = setTimeout(() => { + checkScrollable(); + }, checkInterval); + return () => clearTimeout(timer); + }, [checkScrollable, checkInterval]); + + useEffect(() => { + if (!scrollRef.current) return; + + if (typeof ResizeObserver === 'undefined') { + if (typeof MutationObserver !== 'undefined') { + const observer = new MutationObserver(() => { + debouncedCheckScrollable(); + }); + + observer.observe(scrollRef.current, { + childList: true, + subtree: true, + attributes: true, + characterData: true, + }); + + return () => observer.disconnect(); + } + return; } - }); - resizeObserverRef.current.observe(scrollRef.current); + resizeObserverRef.current = new ResizeObserver((entries) => { + for (const entry of entries) { + debouncedCheckScrollable(); + } + }); - return () => { - if (resizeObserverRef.current) { - resizeObserverRef.current.disconnect(); - } - }; - }, [debouncedCheckScrollable]); + resizeObserverRef.current.observe(scrollRef.current); - useEffect(() => { - return () => { - if (debounceTimerRef.current) { - clearTimeout(debounceTimerRef.current); - } - }; - }, []); + return () => { + if (resizeObserverRef.current) { + resizeObserverRef.current.disconnect(); + } + }; + }, [debouncedCheckScrollable]); - const containerStyle = useMemo(() => ({ - maxHeight - }), [maxHeight]); + useEffect(() => { + return () => { + if (debounceTimerRef.current) { + clearTimeout(debounceTimerRef.current); + } + }; + }, []); - const fadeIndicatorStyle = useMemo(() => ({ - opacity: showScrollHint ? 1 : 0 - }), [showScrollHint]); + const containerStyle = useMemo( + () => ({ + maxHeight, + }), + [maxHeight], + ); - return ( -
          + const fadeIndicatorStyle = useMemo( + () => ({ + opacity: showScrollHint ? 1 : 0, + }), + [showScrollHint], + ); + + return (
          - {children} +
          + {children} +
          +
          -
          -
          - ); -}); + ); + }, +); ScrollableContainer.displayName = 'ScrollableContainer'; -export default ScrollableContainer; \ No newline at end of file +export default ScrollableContainer; diff --git a/web/src/components/common/ui/SelectableButtonGroup.jsx b/web/src/components/common/ui/SelectableButtonGroup.jsx index ebc900f13..3fe249084 100644 --- a/web/src/components/common/ui/SelectableButtonGroup.jsx +++ b/web/src/components/common/ui/SelectableButtonGroup.jsx @@ -17,10 +17,20 @@ along with this program. If not, see . For commercial licensing, please contact support@quantumnous.com */ -import React, { useState } from 'react'; -import { useIsMobile } from '../../../hooks/common/useIsMobile'; +import React, { useState, useRef, useEffect } from 'react'; import { useMinimumLoadingTime } from '../../../hooks/common/useMinimumLoadingTime'; -import { Divider, Button, Tag, Row, Col, Collapsible, Checkbox, Skeleton, Tooltip } from '@douyinfe/semi-ui'; +import { useContainerWidth } from '../../../hooks/common/useContainerWidth'; +import { + Divider, + Button, + Tag, + Row, + Col, + Collapsible, + Checkbox, + Skeleton, + Tooltip, +} from '@douyinfe/semi-ui'; import { IconChevronDown, IconChevronUp } from '@douyinfe/semi-icons'; /** @@ -47,22 +57,62 @@ const SelectableButtonGroup = ({ collapsible = true, collapseHeight = 200, withCheckbox = false, - loading = false + loading = false, }) => { const [isOpen, setIsOpen] = useState(false); - const [skeletonCount] = useState(6); - const isMobile = useIsMobile(); - const perRow = 3; + const [skeletonCount] = useState(12); + const [containerRef, containerWidth] = useContainerWidth(); + + const ConditionalTooltipText = ({ text }) => { + const textRef = useRef(null); + const [isOverflowing, setIsOverflowing] = useState(false); + + useEffect(() => { + const el = textRef.current; + if (!el) return; + setIsOverflowing(el.scrollWidth > el.clientWidth); + }, [text, containerWidth]); + + const textElement = ( + + {text} + + ); + + return isOverflowing ? ( + {textElement} + ) : ( + textElement + ); + }; + + // 基于容器宽度计算响应式列数和标签显示策略 + const getResponsiveConfig = () => { + if (containerWidth <= 280) return { columns: 1, showTags: true }; // 极窄:1列+标签 + if (containerWidth <= 380) return { columns: 2, showTags: true }; // 窄屏:2列+标签 + if (containerWidth <= 460) return { columns: 3, showTags: false }; // 中等:3列不加标签 + return { columns: 3, showTags: true }; // 最宽:3列+标签 + }; + + const { columns: perRow, showTags: shouldShowTags } = getResponsiveConfig(); const maxVisibleRows = Math.max(1, Math.floor(collapseHeight / 32)); // Approx row height 32 const needCollapse = collapsible && items.length > perRow * maxVisibleRows; const showSkeleton = useMinimumLoadingTime(loading); + // 统一使用紧凑的网格间距 + const gutterSize = [4, 4]; + + // 计算 Semi UI Col 的 span 值 + const getColSpan = () => { + return Math.floor(24 / perRow); + }; + const maskStyle = isOpen ? {} : { - WebkitMaskImage: - 'linear-gradient(to bottom, black 0%, rgba(0, 0, 0, 1) 60%, rgba(0, 0, 0, 0.2) 80%, transparent 100%)', - }; + WebkitMaskImage: + 'linear-gradient(to bottom, black 0%, rgba(0, 0, 0, 1) 60%, rgba(0, 0, 0, 0.2) 80%, transparent 100%)', + }; const toggle = () => { setIsOpen(!isOpen); @@ -85,28 +135,23 @@ const SelectableButtonGroup = ({ }; const renderSkeletonButtons = () => { - const placeholder = ( - + {Array.from({ length: skeletonCount }).map((_, index) => ( -
          -
          +
          +
          {withCheckbox && ( )} @@ -114,7 +159,7 @@ const SelectableButtonGroup = ({ active style={{ width: `${60 + (index % 3) * 20}px`, - height: 14 + height: 14, }} />
          @@ -128,29 +173,29 @@ const SelectableButtonGroup = ({ ); }; - const contentElement = showSkeleton ? renderSkeletonButtons() : ( - + const contentElement = showSkeleton ? ( + renderSkeletonButtons() + ) : ( + {items.map((item) => { - const isDisabled = item.disabled || (typeof item.tagCount === 'number' && item.tagCount === 0); + const isDisabled = + item.disabled || + (typeof item.tagCount === 'number' && item.tagCount === 0); const isActive = Array.isArray(activeValue) ? activeValue.includes(item.value) : activeValue === item.value; if (withCheckbox) { return ( - + @@ -176,28 +226,27 @@ const SelectableButtonGroup = ({ } return ( - + @@ -208,9 +257,12 @@ const SelectableButtonGroup = ({ ); return ( -
          +
          {title && ( - + {showSkeleton ? ( ) : ( @@ -220,23 +272,30 @@ const SelectableButtonGroup = ({ )} {needCollapse && !showSkeleton ? (
          - + {contentElement} {isOpen ? null : (
          - + {t('展开更多')}
          )} {isOpen && ( -
          - +
          + {t('收起')}
          )} @@ -248,4 +307,4 @@ const SelectableButtonGroup = ({ ); }; -export default SelectableButtonGroup; \ No newline at end of file +export default SelectableButtonGroup; diff --git a/web/src/components/dashboard/AnnouncementsPanel.jsx b/web/src/components/dashboard/AnnouncementsPanel.jsx index e24f8da2f..c62850b3b 100644 --- a/web/src/components/dashboard/AnnouncementsPanel.jsx +++ b/web/src/components/dashboard/AnnouncementsPanel.jsx @@ -21,7 +21,10 @@ import React from 'react'; import { Card, Tag, Timeline, Empty } from '@douyinfe/semi-ui'; import { Bell } from 'lucide-react'; import { marked } from 'marked'; -import { IllustrationConstruction, IllustrationConstructionDark } from '@douyinfe/semi-illustrations'; +import { + IllustrationConstruction, + IllustrationConstructionDark, +} from '@douyinfe/semi-illustrations'; import ScrollableContainer from '../common/ui/ScrollableContainer'; const AnnouncementsPanel = ({ @@ -29,36 +32,43 @@ const AnnouncementsPanel = ({ announcementLegendData, CARD_PROPS, ILLUSTRATION_SIZE, - t + t, }) => { return ( -
          +
          +
          {t('系统公告')} - + {t('显示最新20条')}
          {/* 图例 */} -
          +
          {announcementLegendData.map((legend, index) => ( -
          +
          - {legend.label} + {legend.label}
          ))}
          @@ -66,9 +76,9 @@ const AnnouncementsPanel = ({ } bodyStyle={{ padding: 0 }} > - + {announcementData.length > 0 ? ( - + {announcementData.map((item, idx) => { const htmlExtra = item.extra ? marked.parse(item.extra) : ''; return ( @@ -76,16 +86,20 @@ const AnnouncementsPanel = ({ key={idx} type={item.type || 'default'} time={`${item.relative ? item.relative + ' ' : ''}${item.time}`} - extra={item.extra ? ( -
          - ) : null} + extra={ + item.extra ? ( +
          + ) : null + } >
          @@ -93,10 +107,12 @@ const AnnouncementsPanel = ({ })} ) : ( -
          +
          } - darkModeImage={} + darkModeImage={ + + } title={t('暂无系统公告')} description={t('请联系管理员在系统设置中配置公告信息')} /> @@ -107,4 +123,4 @@ const AnnouncementsPanel = ({ ); }; -export default AnnouncementsPanel; \ No newline at end of file +export default AnnouncementsPanel; diff --git a/web/src/components/dashboard/ApiInfoPanel.jsx b/web/src/components/dashboard/ApiInfoPanel.jsx index 5da250e6e..1c3c3dd3d 100644 --- a/web/src/components/dashboard/ApiInfoPanel.jsx +++ b/web/src/components/dashboard/ApiInfoPanel.jsx @@ -20,7 +20,10 @@ For commercial licensing, please contact support@quantumnous.com import React from 'react'; import { Card, Avatar, Tag, Divider, Empty } from '@douyinfe/semi-ui'; import { Server, Gauge, ExternalLink } from 'lucide-react'; -import { IllustrationConstruction, IllustrationConstructionDark } from '@douyinfe/semi-illustrations'; +import { + IllustrationConstruction, + IllustrationConstructionDark, +} from '@douyinfe/semi-illustrations'; import ScrollableContainer from '../common/ui/ScrollableContainer'; const ApiInfoPanel = ({ @@ -30,12 +33,12 @@ const ApiInfoPanel = ({ CARD_PROPS, FLEX_CENTER_GAP2, ILLUSTRATION_SIZE, - t + t, }) => { return ( @@ -44,66 +47,65 @@ const ApiInfoPanel = ({ } bodyStyle={{ padding: 0 }} > - + {apiInfoData.length > 0 ? ( apiInfoData.map((api) => ( -
          -
          - +
          +
          + {api.route.substring(0, 2)}
          -
          -
          - +
          +
          + {api.route} -
          +
          } - size="small" - color="white" + size='small' + color='white' shape='circle' onClick={() => handleSpeedTest(api.url)} - className="cursor-pointer hover:opacity-80 text-xs" + className='cursor-pointer hover:opacity-80 text-xs' > {t('测速')} } - size="small" - color="white" + size='small' + color='white' shape='circle' - onClick={() => window.open(api.url, '_blank', 'noopener,noreferrer')} - className="cursor-pointer hover:opacity-80 text-xs" + onClick={() => + window.open(api.url, '_blank', 'noopener,noreferrer') + } + className='cursor-pointer hover:opacity-80 text-xs' > {t('跳转')}
          handleCopyUrl(api.url)} > {api.url}
          -
          - {api.description} -
          +
          {api.description}
          )) ) : ( -
          +
          } - darkModeImage={} + darkModeImage={ + + } title={t('暂无API信息')} description={t('请联系管理员在系统设置中配置API信息')} /> @@ -114,4 +116,4 @@ const ApiInfoPanel = ({ ); }; -export default ApiInfoPanel; \ No newline at end of file +export default ApiInfoPanel; diff --git a/web/src/components/dashboard/ChartsPanel.jsx b/web/src/components/dashboard/ChartsPanel.jsx index 595e2e029..0992adace 100644 --- a/web/src/components/dashboard/ChartsPanel.jsx +++ b/web/src/components/dashboard/ChartsPanel.jsx @@ -20,11 +20,6 @@ For commercial licensing, please contact support@quantumnous.com import React from 'react'; import { Card, Tabs, TabPane } from '@douyinfe/semi-ui'; import { PieChart } from 'lucide-react'; -import { - IconHistogram, - IconPulse, - IconPieChart2Stroked -} from '@douyinfe/semi-icons'; import { VChart } from '@visactor/react-vchart'; const ChartsPanel = ({ @@ -38,80 +33,48 @@ const ChartsPanel = ({ CHART_CONFIG, FLEX_CENTER_GAP2, hasApiInfoPanel, - t + t, }) => { return ( +
          {t('模型数据分析')}
          - - - {t('消耗分布')} - - } itemKey="1" /> - - - {t('消耗趋势')} - - } itemKey="2" /> - - - {t('调用次数分布')} - - } itemKey="3" /> - - - {t('调用次数排行')} - - } itemKey="4" /> + {t('消耗分布')}} itemKey='1' /> + {t('消耗趋势')}} itemKey='2' /> + {t('调用次数分布')}} itemKey='3' /> + {t('调用次数排行')}} itemKey='4' />
          } bodyStyle={{ padding: 0 }} > -
          +
          {activeChartTab === '1' && ( - + )} {activeChartTab === '2' && ( - + )} {activeChartTab === '3' && ( - + )} {activeChartTab === '4' && ( - + )}
          ); }; -export default ChartsPanel; \ No newline at end of file +export default ChartsPanel; diff --git a/web/src/components/dashboard/DashboardHeader.jsx b/web/src/components/dashboard/DashboardHeader.jsx index e0be5d859..c2867e90c 100644 --- a/web/src/components/dashboard/DashboardHeader.jsx +++ b/web/src/components/dashboard/DashboardHeader.jsx @@ -27,19 +27,19 @@ const DashboardHeader = ({ showSearchModal, refresh, loading, - t + t, }) => { - const ICON_BUTTON_CLASS = "text-white hover:bg-opacity-80 !rounded-full"; + const ICON_BUTTON_CLASS = 'text-white hover:bg-opacity-80 !rounded-full'; return ( -
          +

          {getGreeting}

          -
          +
          - + footer={ +
          + +
          - )} + } size={isMobile ? 'full-width' : 'large'} > {renderBody()} @@ -206,4 +252,4 @@ const NoticeModal = ({ visible, onClose, isMobile, defaultTab = 'inApp', unreadK ); }; -export default NoticeModal; \ No newline at end of file +export default NoticeModal; diff --git a/web/src/components/layout/PageLayout.jsx b/web/src/components/layout/PageLayout.jsx index dd5080687..f8cdfb0cb 100644 --- a/web/src/components/layout/PageLayout.jsx +++ b/web/src/components/layout/PageLayout.jsx @@ -17,7 +17,7 @@ along with this program. If not, see . For commercial licensing, please contact support@quantumnous.com */ -import HeaderBar from './HeaderBar'; +import HeaderBar from './headerbar'; import { Layout } from '@douyinfe/semi-ui'; import SiderBar from './SiderBar'; import App from '../../App'; @@ -27,7 +27,13 @@ import React, { useContext, useEffect, useState } from 'react'; import { useIsMobile } from '../../hooks/common/useIsMobile'; import { useSidebarCollapsed } from '../../hooks/common/useSidebarCollapsed'; import { useTranslation } from 'react-i18next'; -import { API, getLogo, getSystemName, showError, setStatusData } from '../../helpers'; +import { + API, + getLogo, + getSystemName, + showError, + setStatusData, +} from '../../helpers'; import { UserContext } from '../../context/User'; import { StatusContext } from '../../context/Status'; import { useLocation } from 'react-router-dom'; @@ -42,9 +48,12 @@ const PageLayout = () => { const { i18n } = useTranslation(); const location = useLocation(); - const shouldHideFooter = location.pathname.startsWith('/console') || location.pathname === '/pricing'; + const shouldHideFooter = + location.pathname.startsWith('/console') || + location.pathname === '/pricing'; - const shouldInnerPadding = location.pathname.includes('/console') && + const shouldInnerPadding = + location.pathname.includes('/console') && !location.pathname.startsWith('/console/chat') && location.pathname !== '/console/playground'; @@ -120,7 +129,10 @@ const PageLayout = () => { zIndex: 100, }} > - setDrawerOpen(prev => !prev)} drawerOpen={drawerOpen} /> + setDrawerOpen((prev) => !prev)} + drawerOpen={drawerOpen} + /> { width: 'var(--sidebar-current-width)', }} > - { if (isMobile) setDrawerOpen(false); }} /> + { + if (isMobile) setDrawerOpen(false); + }} + /> )} { const location = useLocation(); useEffect(() => { - if (statusState?.status?.setup === false && location.pathname !== '/setup') { + if ( + statusState?.status?.setup === false && + location.pathname !== '/setup' + ) { window.location.href = '/setup'; } }, [statusState?.status?.setup, location.pathname]); @@ -34,4 +37,4 @@ const SetupCheck = ({ children }) => { return children; }; -export default SetupCheck; \ No newline at end of file +export default SetupCheck; diff --git a/web/src/components/layout/SiderBar.jsx b/web/src/components/layout/SiderBar.jsx index 86c480022..793e48355 100644 --- a/web/src/components/layout/SiderBar.jsx +++ b/web/src/components/layout/SiderBar.jsx @@ -23,17 +23,12 @@ import { useTranslation } from 'react-i18next'; import { getLucideIcon } from '../../helpers/render'; import { ChevronLeft } from 'lucide-react'; import { useSidebarCollapsed } from '../../hooks/common/useSidebarCollapsed'; -import { - isAdmin, - isRoot, - showError -} from '../../helpers'; +import { useSidebar } from '../../hooks/common/useSidebar'; +import { useMinimumLoadingTime } from '../../hooks/common/useMinimumLoadingTime'; +import { isAdmin, isRoot, showError } from '../../helpers'; +import SkeletonWrapper from './components/SkeletonWrapper'; -import { - Nav, - Divider, - Button, -} from '@douyinfe/semi-ui'; +import { Nav, Divider, Button } from '@douyinfe/semi-ui'; const routerMap = { home: '/', @@ -54,9 +49,16 @@ const routerMap = { personal: '/console/personal', }; -const SiderBar = ({ onNavigate = () => { } }) => { +const SiderBar = ({ onNavigate = () => {} }) => { const { t } = useTranslation(); const [collapsed, toggleCollapsed] = useSidebarCollapsed(); + const { + isModuleVisible, + hasSectionVisibleModules, + loading: sidebarLoading, + } = useSidebar(); + + const showSkeleton = useMinimumLoadingTime(sidebarLoading); const [selectedKeys, setSelectedKeys] = useState(['home']); const [chatItems, setChatItems] = useState([]); @@ -64,8 +66,8 @@ const SiderBar = ({ onNavigate = () => { } }) => { const location = useLocation(); const [routerMapState, setRouterMapState] = useState(routerMap); - const workspaceItems = useMemo( - () => [ + const workspaceItems = useMemo(() => { + const items = [ { text: t('数据看板'), itemKey: 'detail', @@ -101,17 +103,25 @@ const SiderBar = ({ onNavigate = () => { } }) => { className: localStorage.getItem('enable_task') === 'true' ? '' : 'tableHiddle', }, - ], - [ - localStorage.getItem('enable_data_export'), - localStorage.getItem('enable_drawing'), - localStorage.getItem('enable_task'), - t, - ], - ); + ]; - const financeItems = useMemo( - () => [ + // 根据配置过滤项目 + const filteredItems = items.filter((item) => { + const configVisible = isModuleVisible('console', item.itemKey); + return configVisible; + }); + + return filteredItems; + }, [ + localStorage.getItem('enable_data_export'), + localStorage.getItem('enable_drawing'), + localStorage.getItem('enable_task'), + t, + isModuleVisible, + ]); + + const financeItems = useMemo(() => { + const items = [ { text: t('钱包管理'), itemKey: 'topup', @@ -122,12 +132,19 @@ const SiderBar = ({ onNavigate = () => { } }) => { itemKey: 'personal', to: '/personal', }, - ], - [t], - ); + ]; - const adminItems = useMemo( - () => [ + // 根据配置过滤项目 + const filteredItems = items.filter((item) => { + const configVisible = isModuleVisible('personal', item.itemKey); + return configVisible; + }); + + return filteredItems; + }, [t, isModuleVisible]); + + const adminItems = useMemo(() => { + const items = [ { text: t('渠道管理'), itemKey: 'channel', @@ -158,12 +175,19 @@ const SiderBar = ({ onNavigate = () => { } }) => { to: '/setting', className: isRoot() ? '' : 'tableHiddle', }, - ], - [isAdmin(), isRoot(), t], - ); + ]; - const chatMenuItems = useMemo( - () => [ + // 根据配置过滤项目 + const filteredItems = items.filter((item) => { + const configVisible = isModuleVisible('admin', item.itemKey); + return configVisible; + }); + + return filteredItems; + }, [isAdmin(), isRoot(), t, isModuleVisible]); + + const chatMenuItems = useMemo(() => { + const items = [ { text: t('操练场'), itemKey: 'playground', @@ -174,9 +198,16 @@ const SiderBar = ({ onNavigate = () => { } }) => { itemKey: 'chat', items: chatItems, }, - ], - [chatItems, t], - ); + ]; + + // 根据配置过滤项目 + const filteredItems = items.filter((item) => { + const configVisible = isModuleVisible('chat', item.itemKey); + return configVisible; + }); + + return filteredItems; + }, [chatItems, t, isModuleVisible]); // 更新路由映射,添加聊天路由 const updateRouterMapWithChats = (chats) => { @@ -221,7 +252,6 @@ const SiderBar = ({ onNavigate = () => { } }) => { updateRouterMapWithChats(chats); } } catch (e) { - console.error(e); showError('聊天数据解析失败'); } } @@ -275,14 +305,15 @@ const SiderBar = ({ onNavigate = () => { } }) => { key={item.itemKey} itemKey={item.itemKey} text={ -
          - - {item.text} - -
          + + {item.text} + } icon={ -
          +
          {getLucideIcon(item.itemKey, isSelected)}
          } @@ -302,14 +333,15 @@ const SiderBar = ({ onNavigate = () => { } }) => { key={item.itemKey} itemKey={item.itemKey} text={ -
          - - {item.text} - -
          + + {item.text} + } icon={ -
          +
          {getLucideIcon(item.itemKey, isSelected)}
          } @@ -323,7 +355,10 @@ const SiderBar = ({ onNavigate = () => { } }) => { key={subItem.itemKey} itemKey={subItem.itemKey} text={ - + {subItem.text} } @@ -339,107 +374,143 @@ const SiderBar = ({ onNavigate = () => { } }) => { return (
          - + {/* 底部折叠按钮 */} -
          - + +
          ); diff --git a/web/src/components/layout/components/SkeletonWrapper.jsx b/web/src/components/layout/components/SkeletonWrapper.jsx new file mode 100644 index 000000000..7fbd588ca --- /dev/null +++ b/web/src/components/layout/components/SkeletonWrapper.jsx @@ -0,0 +1,394 @@ +/* +Copyright (C) 2025 QuantumNous + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as +published by the Free Software Foundation, either version 3 of the +License, or (at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see . + +For commercial licensing, please contact support@quantumnous.com +*/ + +import React from 'react'; +import { Skeleton } from '@douyinfe/semi-ui'; + +const SkeletonWrapper = ({ + loading = false, + type = 'text', + count = 1, + width = 60, + height = 16, + isMobile = false, + className = '', + collapsed = false, + showAdmin = true, + children, + ...props +}) => { + if (!loading) { + return children; + } + + // 导航链接骨架屏 + const renderNavigationSkeleton = () => { + const skeletonLinkClasses = isMobile + ? 'flex items-center gap-1 p-1 w-full rounded-md' + : 'flex items-center gap-1 p-2 rounded-md'; + + return Array(count) + .fill(null) + .map((_, index) => ( +
          + + } + /> +
          + )); + }; + + // 用户区域骨架屏 (头像 + 文本) + const renderUserAreaSkeleton = () => { + return ( +
          + + } + /> +
          + + } + /> +
          +
          + ); + }; + + // Logo图片骨架屏 + const renderImageSkeleton = () => { + return ( + + } + /> + ); + }; + + // 系统名称骨架屏 + const renderTitleSkeleton = () => { + return ( + } + /> + ); + }; + + // 通用文本骨架屏 + const renderTextSkeleton = () => { + return ( +
          + } + /> +
          + ); + }; + + // 按钮骨架屏(支持圆角) + const renderButtonSkeleton = () => { + return ( +
          + + } + /> +
          + ); + }; + + // 侧边栏导航项骨架屏 (图标 + 文本) + const renderSidebarNavItemSkeleton = () => { + return Array(count) + .fill(null) + .map((_, index) => ( +
          + {/* 图标骨架屏 */} +
          + + } + /> +
          + {/* 文本骨架屏 */} + + } + /> +
          + )); + }; + + // 侧边栏组标题骨架屏 + const renderSidebarGroupTitleSkeleton = () => { + return ( +
          + + } + /> +
          + ); + }; + + // 完整侧边栏骨架屏 - 1:1 还原,去重实现 + const renderSidebarSkeleton = () => { + const NAV_WIDTH = 164; + const NAV_HEIGHT = 30; + const COLLAPSED_WIDTH = 44; + const COLLAPSED_HEIGHT = 44; + const ICON_SIZE = 16; + const TITLE_HEIGHT = 12; + const TEXT_HEIGHT = 16; + + const renderIcon = () => ( + + } + /> + ); + + const renderLabel = (labelWidth) => ( + + } + /> + ); + + const NavRow = ({ labelWidth }) => ( +
          +
          + {renderIcon()} +
          + {renderLabel(labelWidth)} +
          + ); + + const CollapsedRow = ({ keyPrefix, index }) => ( +
          + + } + /> +
          + ); + + if (collapsed) { + return ( +
          + {Array(2) + .fill(null) + .map((_, i) => ( + + ))} + {Array(5) + .fill(null) + .map((_, i) => ( + + ))} + {Array(2) + .fill(null) + .map((_, i) => ( + + ))} + {Array(5) + .fill(null) + .map((_, i) => ( + + ))} +
          + ); + } + + const sections = [ + { key: 'chat', titleWidth: 32, itemWidths: [54, 32], wrapper: 'section' }, + { key: 'console', titleWidth: 48, itemWidths: [64, 64, 64, 64, 64] }, + { key: 'personal', titleWidth: 64, itemWidths: [64, 64] }, + ...(showAdmin + ? [{ key: 'admin', titleWidth: 48, itemWidths: [64, 64, 80, 64, 64] }] + : []), + ]; + + return ( +
          + {sections.map((sec, idx) => ( + + {sec.wrapper === 'section' ? ( +
          +
          + + } + /> +
          + {sec.itemWidths.map((w, i) => ( + + ))} +
          + ) : ( +
          +
          + + } + /> +
          + {sec.itemWidths.map((w, i) => ( + + ))} +
          + )} +
          + ))} +
          + ); + }; + + // 根据类型渲染不同的骨架屏 + switch (type) { + case 'navigation': + return renderNavigationSkeleton(); + case 'userArea': + return renderUserAreaSkeleton(); + case 'image': + return renderImageSkeleton(); + case 'title': + return renderTitleSkeleton(); + case 'sidebarNavItem': + return renderSidebarNavItemSkeleton(); + case 'sidebarGroupTitle': + return renderSidebarGroupTitleSkeleton(); + case 'sidebar': + return renderSidebarSkeleton(); + case 'button': + return renderButtonSkeleton(); + case 'text': + default: + return renderTextSkeleton(); + } +}; + +export default SkeletonWrapper; diff --git a/web/src/components/layout/HeaderBar/ActionButtons.jsx b/web/src/components/layout/headerbar/ActionButtons.jsx similarity index 91% rename from web/src/components/layout/HeaderBar/ActionButtons.jsx rename to web/src/components/layout/headerbar/ActionButtons.jsx index 5717fbff7..545b5227b 100644 --- a/web/src/components/layout/HeaderBar/ActionButtons.jsx +++ b/web/src/components/layout/headerbar/ActionButtons.jsx @@ -41,7 +41,7 @@ const ActionButtons = ({ t, }) => { return ( -
          +
          - + -
          - + +
          + logo
          -
          -
          +
          +
          - + {systemName} {(isSelfUseMode || isDemoSiteMode) && !isLoading && ( {isSelfUseMode ? t('自用模式') : t('演示站点')} diff --git a/web/src/components/layout/HeaderBar/LanguageSelector.jsx b/web/src/components/layout/headerbar/LanguageSelector.jsx similarity index 85% rename from web/src/components/layout/HeaderBar/LanguageSelector.jsx rename to web/src/components/layout/headerbar/LanguageSelector.jsx index 68bdcb012..cbfd69b35 100644 --- a/web/src/components/layout/HeaderBar/LanguageSelector.jsx +++ b/web/src/components/layout/headerbar/LanguageSelector.jsx @@ -25,21 +25,21 @@ import { CN, GB } from 'country-flag-icons/react/3x2'; const LanguageSelector = ({ currentLang, onLanguageChange, t }) => { return ( + onLanguageChange('zh')} className={`!flex !items-center !gap-2 !px-3 !py-1.5 !text-sm !text-semi-color-text-0 dark:!text-gray-200 ${currentLang === 'zh' ? '!bg-semi-color-primary-light-default dark:!bg-blue-600 !font-semibold' : 'hover:!bg-semi-color-fill-1 dark:hover:!bg-gray-600'}`} > - + 中文 onLanguageChange('en')} className={`!flex !items-center !gap-2 !px-3 !py-1.5 !text-sm !text-semi-color-text-0 dark:!text-gray-200 ${currentLang === 'en' ? '!bg-semi-color-primary-light-default dark:!bg-blue-600 !font-semibold' : 'hover:!bg-semi-color-fill-1 dark:hover:!bg-gray-600'}`} > - + English @@ -48,9 +48,9 @@ const LanguageSelector = ({ currentLang, onLanguageChange, t }) => { ); } else { const showRegisterButton = !isSelfUseMode; - const commonSizingAndLayoutClass = "flex items-center justify-center !py-[10px] !px-1.5"; + const commonSizingAndLayoutClass = + 'flex items-center justify-center !py-[10px] !px-1.5'; - const loginButtonSpecificStyling = "!bg-semi-color-fill-0 dark:!bg-semi-color-fill-1 hover:!bg-semi-color-fill-1 dark:hover:!bg-gray-700 transition-colors"; + const loginButtonSpecificStyling = + '!bg-semi-color-fill-0 dark:!bg-semi-color-fill-1 hover:!bg-semi-color-fill-1 dark:hover:!bg-gray-700 transition-colors'; let loginButtonClasses = `${commonSizingAndLayoutClass} ${loginButtonSpecificStyling}`; let registerButtonClasses = `${commonSizingAndLayoutClass}`; - const loginButtonTextSpanClass = "!text-xs !text-semi-color-text-1 dark:!text-gray-300 !p-1.5"; - const registerButtonTextSpanClass = "!text-xs !text-white !p-1.5"; + const loginButtonTextSpanClass = + '!text-xs !text-semi-color-text-1 dark:!text-gray-300 !p-1.5'; + const registerButtonTextSpanClass = '!text-xs !text-white !p-1.5'; if (showRegisterButton) { if (isMobile) { - loginButtonClasses += " !rounded-full"; + loginButtonClasses += ' !rounded-full'; } else { - loginButtonClasses += " !rounded-l-full !rounded-r-none"; + loginButtonClasses += ' !rounded-l-full !rounded-r-none'; } - registerButtonClasses += " !rounded-r-full !rounded-l-none"; + registerButtonClasses += ' !rounded-r-full !rounded-l-none'; } else { - loginButtonClasses += " !rounded-full"; + loginButtonClasses += ' !rounded-full'; } return ( -
          - +
          + {showRegisterButton && ( -
          - +
          +
          diff --git a/web/src/components/layout/HeaderBar/index.jsx b/web/src/components/layout/headerbar/index.jsx similarity index 89% rename from web/src/components/layout/HeaderBar/index.jsx rename to web/src/components/layout/headerbar/index.jsx index 0a0e89545..81b51d7fe 100644 --- a/web/src/components/layout/HeaderBar/index.jsx +++ b/web/src/components/layout/headerbar/index.jsx @@ -44,6 +44,8 @@ const HeaderBar = ({ onMobileMenuToggle, drawerOpen }) => { isDemoSiteMode, isConsoleRoute, theme, + headerNavModules, + pricingRequireAuth, logout, handleLanguageChange, handleThemeToggle, @@ -60,10 +62,10 @@ const HeaderBar = ({ onMobileMenuToggle, drawerOpen }) => { getUnreadKeys, } = useNotifications(statusState); - const { mainNavLinks } = useNavigation(t, docsLink); + const { mainNavLinks } = useNavigation(t, docsLink, headerNavModules); return ( -
          +
          { unreadKeys={getUnreadKeys()} /> -
          -
          -
          +
          +
          +
          { isMobile={isMobile} isLoading={isLoading} userState={userState} + pricingRequireAuth={pricingRequireAuth} /> {/* 聊天头部 */} {styleState.isMobile ? ( -
          +
          ) : ( -
          -
          -
          -
          - +
          +
          +
          +
          +
          - + {t('AI 对话')} - + {inputs.model || t('选择模型开始对话')}
          -
          +
          @@ -97,7 +94,7 @@ const ChatArea = ({ )} {/* 聊天内容区域 */} -
          +
          @@ -129,4 +126,4 @@ const ChatArea = ({ ); }; -export default ChatArea; \ No newline at end of file +export default ChatArea; diff --git a/web/src/components/playground/CodeViewer.jsx b/web/src/components/playground/CodeViewer.jsx index 0e0d0bf54..9d8ae453a 100644 --- a/web/src/components/playground/CodeViewer.jsx +++ b/web/src/components/playground/CodeViewer.jsx @@ -102,15 +102,17 @@ const highlightJson = (str) => { color = '#569cd6'; } return `${match}`; - } + }, ); }; const isJsonLike = (content, language) => { if (language === 'json') return true; const trimmed = content.trim(); - return (trimmed.startsWith('{') && trimmed.endsWith('}')) || - (trimmed.startsWith('[') && trimmed.endsWith(']')); + return ( + (trimmed.startsWith('{') && trimmed.endsWith('}')) || + (trimmed.startsWith('[') && trimmed.endsWith(']')) + ); }; const formatContent = (content) => { @@ -148,7 +150,10 @@ const CodeViewer = ({ content, title, language = 'json' }) => { const contentMetrics = useMemo(() => { const length = formattedContent.length; const isLarge = length > PERFORMANCE_CONFIG.MAX_DISPLAY_LENGTH; - const isVeryLarge = length > PERFORMANCE_CONFIG.MAX_DISPLAY_LENGTH * PERFORMANCE_CONFIG.VERY_LARGE_MULTIPLIER; + const isVeryLarge = + length > + PERFORMANCE_CONFIG.MAX_DISPLAY_LENGTH * + PERFORMANCE_CONFIG.VERY_LARGE_MULTIPLIER; return { length, isLarge, isVeryLarge }; }, [formattedContent.length]); @@ -156,8 +161,10 @@ const CodeViewer = ({ content, title, language = 'json' }) => { if (!contentMetrics.isLarge || isExpanded) { return formattedContent; } - return formattedContent.substring(0, PERFORMANCE_CONFIG.PREVIEW_LENGTH) + - '\n\n// ... 内容被截断以提升性能 ...'; + return ( + formattedContent.substring(0, PERFORMANCE_CONFIG.PREVIEW_LENGTH) + + '\n\n// ... 内容被截断以提升性能 ...' + ); }, [formattedContent, contentMetrics.isLarge, isExpanded]); const highlightedContent = useMemo(() => { @@ -174,9 +181,10 @@ const CodeViewer = ({ content, title, language = 'json' }) => { const handleCopy = useCallback(async () => { try { - const textToCopy = typeof content === 'object' && content !== null - ? JSON.stringify(content, null, 2) - : content; + const textToCopy = + typeof content === 'object' && content !== null + ? JSON.stringify(content, null, 2) + : content; const success = await copy(textToCopy); setCopied(true); @@ -205,11 +213,12 @@ const CodeViewer = ({ content, title, language = 'json' }) => { }, [isExpanded, contentMetrics.isVeryLarge]); if (!content) { - const placeholderText = { - preview: t('正在构造请求体预览...'), - request: t('暂无请求数据'), - response: t('暂无响应数据') - }[title] || t('暂无数据'); + const placeholderText = + { + preview: t('正在构造请求体预览...'), + request: t('暂无请求数据'), + response: t('暂无响应数据'), + }[title] || t('暂无数据'); return (
          @@ -222,7 +231,7 @@ const CodeViewer = ({ content, title, language = 'json' }) => { const contentPadding = contentMetrics.isLarge ? '52px' : '16px'; return ( -
          +
          {/* 性能警告 */} {contentMetrics.isLarge && (
          @@ -250,8 +259,8 @@ const CodeViewer = ({ content, title, language = 'json' }) => { @@ -329,4 +354,4 @@ const CodeViewer = ({ content, title, language = 'json' }) => { ); }; -export default CodeViewer; \ No newline at end of file +export default CodeViewer; diff --git a/web/src/components/playground/ConfigManager.jsx b/web/src/components/playground/ConfigManager.jsx index 753d11380..7eaa35b8a 100644 --- a/web/src/components/playground/ConfigManager.jsx +++ b/web/src/components/playground/ConfigManager.jsx @@ -18,21 +18,16 @@ For commercial licensing, please contact support@quantumnous.com */ import React, { useRef } from 'react'; -import { - Button, - Typography, - Toast, - Modal, - Dropdown, -} from '@douyinfe/semi-ui'; -import { - Download, - Upload, - RotateCcw, - Settings2, -} from 'lucide-react'; +import { Button, Typography, Toast, Modal, Dropdown } from '@douyinfe/semi-ui'; +import { Download, Upload, RotateCcw, Settings2 } from 'lucide-react'; import { useTranslation } from 'react-i18next'; -import { exportConfig, importConfig, clearConfig, hasStoredConfig, getConfigTimestamp } from './configStorage'; +import { + exportConfig, + importConfig, + clearConfig, + hasStoredConfig, + getConfigTimestamp, +} from './configStorage'; const ConfigManager = ({ currentConfig, @@ -51,7 +46,10 @@ const ConfigManager = ({ ...currentConfig, timestamp: new Date().toISOString(), }; - localStorage.setItem('playground_config', JSON.stringify(configWithTimestamp)); + localStorage.setItem( + 'playground_config', + JSON.stringify(configWithTimestamp), + ); exportConfig(currentConfig, messages); Toast.success({ @@ -104,7 +102,9 @@ const ConfigManager = ({ const handleReset = () => { Modal.confirm({ title: t('重置配置'), - content: t('将清除所有保存的配置并恢复默认设置,此操作不可撤销。是否继续?'), + content: t( + '将清除所有保存的配置并恢复默认设置,此操作不可撤销。是否继续?', + ), okText: t('确定重置'), cancelText: t('取消'), okButtonProps: { @@ -114,7 +114,9 @@ const ConfigManager = ({ // 询问是否同时重置消息 Modal.confirm({ title: t('重置选项'), - content: t('是否同时重置对话消息?选择"是"将清空所有对话记录并恢复默认示例;选择"否"将保留当前对话记录。'), + content: t( + '是否同时重置对话消息?选择"是"将清空所有对话记录并恢复默认示例;选择"否"将保留当前对话记录。', + ), okText: t('同时重置消息'), cancelText: t('仅重置配置'), okButtonProps: { @@ -159,7 +161,7 @@ const ConfigManager = ({ name: 'export', onClick: handleExport, children: ( -
          +
          {t('导出配置')}
          @@ -170,7 +172,7 @@ const ConfigManager = ({ name: 'import', onClick: handleImportClick, children: ( -
          +
          {t('导入配置')}
          @@ -184,7 +186,7 @@ const ConfigManager = ({ name: 'reset', onClick: handleReset, children: ( -
          +
          {t('重置配置')}
          @@ -197,24 +199,24 @@ const ConfigManager = ({ return ( <>
          {/* 导出和导入按钮 */} -
          +
          @@ -267,8 +269,8 @@ const ConfigManager = ({ @@ -276,4 +278,4 @@ const ConfigManager = ({ ); }; -export default ConfigManager; \ No newline at end of file +export default ConfigManager; diff --git a/web/src/components/playground/CustomInputRender.jsx b/web/src/components/playground/CustomInputRender.jsx index 2191cb165..464cfa3b1 100644 --- a/web/src/components/playground/CustomInputRender.jsx +++ b/web/src/components/playground/CustomInputRender.jsx @@ -21,23 +21,24 @@ import React from 'react'; const CustomInputRender = (props) => { const { detailProps } = props; - const { clearContextNode, uploadNode, inputNode, sendNode, onClick } = detailProps; + const { clearContextNode, uploadNode, inputNode, sendNode, onClick } = + detailProps; // 清空按钮 const styledClearNode = clearContextNode ? React.cloneElement(clearContextNode, { - className: `!rounded-full !bg-gray-100 hover:!bg-red-500 hover:!text-white flex-shrink-0 transition-all ${clearContextNode.props.className || ''}`, - style: { - ...clearContextNode.props.style, - width: '32px', - height: '32px', - minWidth: '32px', - padding: 0, - display: 'flex', - alignItems: 'center', - justifyContent: 'center', - } - }) + className: `!rounded-full !bg-gray-100 hover:!bg-red-500 hover:!text-white flex-shrink-0 transition-all ${clearContextNode.props.className || ''}`, + style: { + ...clearContextNode.props.style, + width: '32px', + height: '32px', + minWidth: '32px', + padding: 0, + display: 'flex', + alignItems: 'center', + justifyContent: 'center', + }, + }) : null; // 发送按钮 @@ -52,21 +53,19 @@ const CustomInputRender = (props) => { display: 'flex', alignItems: 'center', justifyContent: 'center', - } + }, }); return ( -
          +
          {/* 清空对话按钮 - 左边 */} {styledClearNode} -
          - {inputNode} -
          +
          {inputNode}
          {/* 发送按钮 - 右边 */} {styledSendNode}
          @@ -74,4 +73,4 @@ const CustomInputRender = (props) => { ); }; -export default CustomInputRender; \ No newline at end of file +export default CustomInputRender; diff --git a/web/src/components/playground/CustomRequestEditor.jsx b/web/src/components/playground/CustomRequestEditor.jsx index e411d9e78..26b3ff504 100644 --- a/web/src/components/playground/CustomRequestEditor.jsx +++ b/web/src/components/playground/CustomRequestEditor.jsx @@ -25,13 +25,7 @@ import { Switch, Banner, } from '@douyinfe/semi-ui'; -import { - Code, - Edit, - Check, - X, - AlertTriangle, -} from 'lucide-react'; +import { Code, Edit, Check, X, AlertTriangle } from 'lucide-react'; import { useTranslation } from 'react-i18next'; const CustomRequestEditor = ({ @@ -48,12 +42,22 @@ const CustomRequestEditor = ({ // 当切换到自定义模式时,用默认payload初始化 useEffect(() => { - if (customRequestMode && (!customRequestBody || customRequestBody.trim() === '')) { - const defaultJson = defaultPayload ? JSON.stringify(defaultPayload, null, 2) : ''; + if ( + customRequestMode && + (!customRequestBody || customRequestBody.trim() === '') + ) { + const defaultJson = defaultPayload + ? JSON.stringify(defaultPayload, null, 2) + : ''; setLocalValue(defaultJson); onCustomRequestBodyChange(defaultJson); } - }, [customRequestMode, defaultPayload, customRequestBody, onCustomRequestBodyChange]); + }, [ + customRequestMode, + defaultPayload, + customRequestBody, + onCustomRequestBodyChange, + ]); // 同步外部传入的customRequestBody到本地状态 useEffect(() => { @@ -113,21 +117,21 @@ const CustomRequestEditor = ({ }; return ( -
          +
          {/* 自定义模式开关 */} -
          -
          - - +
          +
          + + 自定义请求体模式
          @@ -135,43 +139,43 @@ const CustomRequestEditor = ({ <> {/* 提示信息 */} } - className="!rounded-lg" + className='!rounded-lg' closeIcon={null} /> {/* JSON编辑器 */}
          -
          - +
          + 请求体 JSON -
          +
          {isValid ? ( -
          +
          - + 格式正确
          ) : ( -
          +
          - + 格式错误
          )} @@ -191,12 +195,12 @@ const CustomRequestEditor = ({ /> {!isValid && errorMessage && ( - + {errorMessage} )} - + 请输入有效的JSON格式的请求体。您可以参考预览面板中的默认请求体格式。
          @@ -206,4 +210,4 @@ const CustomRequestEditor = ({ ); }; -export default CustomRequestEditor; \ No newline at end of file +export default CustomRequestEditor; diff --git a/web/src/components/playground/DebugPanel.jsx b/web/src/components/playground/DebugPanel.jsx index 24158c2b2..d931ff61c 100644 --- a/web/src/components/playground/DebugPanel.jsx +++ b/web/src/components/playground/DebugPanel.jsx @@ -26,14 +26,7 @@ import { Button, Dropdown, } from '@douyinfe/semi-ui'; -import { - Code, - Zap, - Clock, - X, - Eye, - Send, -} from 'lucide-react'; +import { Code, Zap, Clock, X, Eye, Send } from 'lucide-react'; import { useTranslation } from 'react-i18next'; import CodeViewer from './CodeViewer'; @@ -76,7 +69,7 @@ const DebugPanel = ({ - {items.map(item => { + {items.map((item) => { return ( -
          -
          -
          - +
          +
          +
          +
          - + {t('调试信息')}
          @@ -127,75 +120,84 @@ const DebugPanel = ({
          -
          +
          - - - {t('预览请求体')} - {customRequestMode && ( - - 自定义 - - )} -
          - } itemKey="preview"> + + + {t('预览请求体')} + {customRequestMode && ( + + 自定义 + + )} +
          + } + itemKey='preview' + > - - - {t('实际请求体')} -
          - } itemKey="request"> + + + {t('实际请求体')} +
          + } + itemKey='request' + > - - - {t('响应')} -
          - } itemKey="response"> + + + {t('响应')} +
          + } + itemKey='response' + >
          -
          +
          {(debugData.timestamp || debugData.previewTimestamp) && ( -
          - - +
          + + {activeKey === 'preview' && debugData.previewTimestamp ? `${t('预览更新')}: ${new Date(debugData.previewTimestamp).toLocaleString()}` : debugData.timestamp @@ -209,4 +211,4 @@ const DebugPanel = ({ ); }; -export default DebugPanel; \ No newline at end of file +export default DebugPanel; diff --git a/web/src/components/playground/FloatingButtons.jsx b/web/src/components/playground/FloatingButtons.jsx index 87a3b0b55..3d024df4c 100644 --- a/web/src/components/playground/FloatingButtons.jsx +++ b/web/src/components/playground/FloatingButtons.jsx @@ -19,11 +19,7 @@ For commercial licensing, please contact support@quantumnous.com import React from 'react'; import { Button } from '@douyinfe/semi-ui'; -import { - Settings, - Eye, - EyeOff, -} from 'lucide-react'; +import { Settings, Eye, EyeOff } from 'lucide-react'; const FloatingButtons = ({ styleState, @@ -55,7 +51,7 @@ const FloatingButtons = ({ onClick={onToggleSettings} theme='solid' type='primary' - className="lg:hidden" + className='lg:hidden' /> )} @@ -64,8 +60,8 @@ const FloatingButtons = ({
          {!imageEnabled ? ( - - {disabled ? '图片功能在自定义请求体模式下不可用' : '启用后可添加图片URL进行多模态对话'} + + {disabled + ? '图片功能在自定义请求体模式下不可用' + : '启用后可添加图片URL进行多模态对话'} ) : imageUrls.length === 0 ? ( - - {disabled ? '图片功能在自定义请求体模式下不可用' : '点击 + 按钮添加图片URL进行多模态对话'} + + {disabled + ? '图片功能在自定义请求体模式下不可用' + : '点击 + 按钮添加图片URL进行多模态对话'} ) : ( - - 已添加 {imageUrls.length} 张图片{disabled ? ' (自定义模式下不可用)' : ''} + + 已添加 {imageUrls.length} 张图片 + {disabled ? ' (自定义模式下不可用)' : ''} )} -
          +
          {imageUrls.map((url, index) => ( -
          -
          +
          +
          handleUpdateImageUrl(index, value)} - className="!rounded-lg" - size="small" + className='!rounded-lg' + size='small' prefix={} disabled={!imageEnabled || disabled} />
          @@ -129,4 +137,4 @@ const ImageUrlInput = ({ imageUrls, imageEnabled, onImageUrlsChange, onImageEnab ); }; -export default ImageUrlInput; \ No newline at end of file +export default ImageUrlInput; diff --git a/web/src/components/playground/MessageActions.jsx b/web/src/components/playground/MessageActions.jsx index 64775ae52..093700367 100644 --- a/web/src/components/playground/MessageActions.jsx +++ b/web/src/components/playground/MessageActions.jsx @@ -18,17 +18,8 @@ For commercial licensing, please contact support@quantumnous.com */ import React from 'react'; -import { - Button, - Tooltip, -} from '@douyinfe/semi-ui'; -import { - RefreshCw, - Copy, - Trash2, - UserCheck, - Edit, -} from 'lucide-react'; +import { Button, Tooltip } from '@douyinfe/semi-ui'; +import { RefreshCw, Copy, Trash2, UserCheck, Edit } from 'lucide-react'; import { useTranslation } from 'react-i18next'; const MessageActions = ({ @@ -40,23 +31,32 @@ const MessageActions = ({ onRoleToggle, onMessageEdit, isAnyMessageGenerating = false, - isEditing = false + isEditing = false, }) => { const { t } = useTranslation(); - const isLoading = message.status === 'loading' || message.status === 'incomplete'; + const isLoading = + message.status === 'loading' || message.status === 'incomplete'; const shouldDisableActions = isAnyMessageGenerating || isEditing; - const canToggleRole = message.role === 'assistant' || message.role === 'system'; - const canEdit = !isLoading && message.content && typeof onMessageEdit === 'function' && !isEditing; + const canToggleRole = + message.role === 'assistant' || message.role === 'system'; + const canEdit = + !isLoading && + message.content && + typeof onMessageEdit === 'function' && + !isEditing; return ( -
          +
          {!isLoading && ( - +