diff --git a/controller/channel.go b/controller/channel.go index 9fcc95e0b..afc6600e7 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -209,157 +209,14 @@ func FetchUpstreamModels(c *gin.Context) { return } - baseURL := constant.ChannelBaseURLs[channel.Type] - if channel.GetBaseURL() != "" { - baseURL = channel.GetBaseURL() - } - - // 对于 Ollama 渠道,使用特殊处理 - if channel.Type == constant.ChannelTypeOllama { - key := strings.Split(channel.Key, "\n")[0] - models, err := ollama.FetchOllamaModels(baseURL, key) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": fmt.Sprintf("获取Ollama模型失败: %s", err.Error()), - }) - return - } - - result := OpenAIModelsResponse{ - Data: make([]OpenAIModel, 0, len(models)), - } - - for _, modelInfo := range models { - metadata := map[string]any{} - if modelInfo.Size > 0 { - metadata["size"] = modelInfo.Size - } - if modelInfo.Digest != "" { - metadata["digest"] = modelInfo.Digest - } - if modelInfo.ModifiedAt != "" { - metadata["modified_at"] = modelInfo.ModifiedAt - } - details := modelInfo.Details - if details.ParentModel != "" || details.Format != "" || details.Family != "" || len(details.Families) > 0 || details.ParameterSize != "" || details.QuantizationLevel != "" { - metadata["details"] = modelInfo.Details - } - if len(metadata) == 0 { - metadata = nil - } - - result.Data = append(result.Data, OpenAIModel{ - ID: modelInfo.Name, - Object: "model", - Created: 0, - OwnedBy: "ollama", - Metadata: metadata, - }) - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "data": result.Data, - }) - return - } - - // 对于 Gemini 渠道,使用特殊处理 - if channel.Type == constant.ChannelTypeGemini { - // 获取用于请求的可用密钥(多密钥渠道优先使用启用状态的密钥) - key, _, apiErr := channel.GetNextEnabledKey() - if apiErr != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": fmt.Sprintf("获取渠道密钥失败: %s", apiErr.Error()), - }) - return - } - key = strings.TrimSpace(key) - models, err := gemini.FetchGeminiModels(baseURL, key, 
channel.GetSetting().Proxy) - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": fmt.Sprintf("获取Gemini模型失败: %s", err.Error()), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": models, - }) - return - } - - var url string - switch channel.Type { - case constant.ChannelTypeAli: - url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL) - case constant.ChannelTypeZhipu_v4: - if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" { - url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL) - } else { - url = fmt.Sprintf("%s/api/paas/v4/models", baseURL) - } - case constant.ChannelTypeVolcEngine: - if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" { - url = fmt.Sprintf("%s/v1/models", plan.OpenAIBaseURL) - } else { - url = fmt.Sprintf("%s/v1/models", baseURL) - } - case constant.ChannelTypeMoonshot: - if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" { - url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL) - } else { - url = fmt.Sprintf("%s/v1/models", baseURL) - } - default: - url = fmt.Sprintf("%s/v1/models", baseURL) - } - - // 获取用于请求的可用密钥(多密钥渠道优先使用启用状态的密钥) - key, _, apiErr := channel.GetNextEnabledKey() - if apiErr != nil { + ids, err := fetchChannelUpstreamModelIDs(channel) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": fmt.Sprintf("获取渠道密钥失败: %s", apiErr.Error()), + "message": fmt.Sprintf("获取模型列表失败: %s", err.Error()), }) return } - key = strings.TrimSpace(key) - - headers, err := buildFetchModelsHeaders(channel, key) - if err != nil { - common.ApiError(c, err) - return - } - - body, err := GetResponseBody("GET", url, channel, headers) - if err != nil { - common.ApiError(c, err) - return - } - - var result OpenAIModelsResponse - if err = json.Unmarshal(body, &result); err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": fmt.Sprintf("解析响应失败: 
%s", err.Error()), - }) - return - } - - var ids []string - for _, model := range result.Data { - id := model.ID - if channel.Type == constant.ChannelTypeGemini { - id = strings.TrimPrefix(id, "models/") - } - ids = append(ids, id) - } c.JSON(http.StatusOK, gin.H{ "success": true, diff --git a/controller/channel_upstream_update.go b/controller/channel_upstream_update.go new file mode 100644 index 000000000..701bef78b --- /dev/null +++ b/controller/channel_upstream_update.go @@ -0,0 +1,983 @@ +package controller + +import ( + "fmt" + "net/http" + "slices" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/constant" + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/model" + "github.com/QuantumNous/new-api/relay/channel/gemini" + "github.com/QuantumNous/new-api/relay/channel/ollama" + "github.com/QuantumNous/new-api/service" + + "github.com/gin-gonic/gin" + "github.com/samber/lo" +) + +const ( + channelUpstreamModelUpdateTaskDefaultIntervalMinutes = 30 + channelUpstreamModelUpdateTaskBatchSize = 100 + channelUpstreamModelUpdateMinCheckIntervalSeconds = 300 + channelUpstreamModelUpdateNotifySuppressWindowSeconds = 86400 + channelUpstreamModelUpdateNotifyMaxChannelDetails = 8 + channelUpstreamModelUpdateNotifyMaxModelDetails = 12 + channelUpstreamModelUpdateNotifyMaxFailedChannelIDs = 10 +) + +var ( + channelUpstreamModelUpdateTaskOnce sync.Once + channelUpstreamModelUpdateTaskRunning atomic.Bool + channelUpstreamModelUpdateNotifyState = struct { + sync.Mutex + lastNotifiedAt int64 + lastChangedChannels int + lastFailedChannels int + }{} +) + +type applyChannelUpstreamModelUpdatesRequest struct { + ID int `json:"id"` + AddModels []string `json:"add_models"` + RemoveModels []string `json:"remove_models"` + IgnoreModels []string `json:"ignore_models"` +} + +type applyAllChannelUpstreamModelUpdatesResult struct { + ChannelID int `json:"channel_id"` + ChannelName string 
`json:"channel_name"` + AddedModels []string `json:"added_models"` + RemovedModels []string `json:"removed_models"` + RemainingModels []string `json:"remaining_models"` + RemainingRemoveModels []string `json:"remaining_remove_models"` +} + +type detectChannelUpstreamModelUpdatesResult struct { + ChannelID int `json:"channel_id"` + ChannelName string `json:"channel_name"` + AddModels []string `json:"add_models"` + RemoveModels []string `json:"remove_models"` + LastCheckTime int64 `json:"last_check_time"` + AutoAddedModels int `json:"auto_added_models"` +} + +type upstreamModelUpdateChannelSummary struct { + ChannelName string + AddCount int + RemoveCount int +} + +func normalizeModelNames(models []string) []string { + return lo.Uniq(lo.FilterMap(models, func(model string, _ int) (string, bool) { + trimmed := strings.TrimSpace(model) + return trimmed, trimmed != "" + })) +} + +func mergeModelNames(base []string, appended []string) []string { + merged := normalizeModelNames(base) + seen := make(map[string]struct{}, len(merged)) + for _, model := range merged { + seen[model] = struct{}{} + } + for _, model := range normalizeModelNames(appended) { + if _, ok := seen[model]; ok { + continue + } + seen[model] = struct{}{} + merged = append(merged, model) + } + return merged +} + +func subtractModelNames(base []string, removed []string) []string { + removeSet := make(map[string]struct{}, len(removed)) + for _, model := range normalizeModelNames(removed) { + removeSet[model] = struct{}{} + } + return lo.Filter(normalizeModelNames(base), func(model string, _ int) bool { + _, ok := removeSet[model] + return !ok + }) +} + +func intersectModelNames(base []string, allowed []string) []string { + allowedSet := make(map[string]struct{}, len(allowed)) + for _, model := range normalizeModelNames(allowed) { + allowedSet[model] = struct{}{} + } + return lo.Filter(normalizeModelNames(base), func(model string, _ int) bool { + _, ok := allowedSet[model] + return ok + }) +} + +func 
applySelectedModelChanges(originModels []string, addModels []string, removeModels []string) []string { + // Add wins when the same model appears in both selected lists. + normalizedAdd := normalizeModelNames(addModels) + normalizedRemove := subtractModelNames(normalizeModelNames(removeModels), normalizedAdd) + return subtractModelNames(mergeModelNames(originModels, normalizedAdd), normalizedRemove) +} + +func normalizeChannelModelMapping(channel *model.Channel) map[string]string { + if channel == nil || channel.ModelMapping == nil { + return nil + } + rawMapping := strings.TrimSpace(*channel.ModelMapping) + if rawMapping == "" || rawMapping == "{}" { + return nil + } + parsed := make(map[string]string) + if err := common.UnmarshalJsonStr(rawMapping, &parsed); err != nil { + return nil + } + normalized := make(map[string]string, len(parsed)) + for source, target := range parsed { + normalizedSource := strings.TrimSpace(source) + normalizedTarget := strings.TrimSpace(target) + if normalizedSource == "" || normalizedTarget == "" { + continue + } + normalized[normalizedSource] = normalizedTarget + } + if len(normalized) == 0 { + return nil + } + return normalized +} + +func collectPendingUpstreamModelChangesFromModels( + localModels []string, + upstreamModels []string, + ignoredModels []string, + modelMapping map[string]string, +) (pendingAddModels []string, pendingRemoveModels []string) { + localSet := make(map[string]struct{}) + localModels = normalizeModelNames(localModels) + upstreamModels = normalizeModelNames(upstreamModels) + for _, modelName := range localModels { + localSet[modelName] = struct{}{} + } + upstreamSet := make(map[string]struct{}, len(upstreamModels)) + for _, modelName := range upstreamModels { + upstreamSet[modelName] = struct{}{} + } + + ignoredSet := make(map[string]struct{}) + for _, modelName := range normalizeModelNames(ignoredModels) { + ignoredSet[modelName] = struct{}{} + } + + redirectSourceSet := make(map[string]struct{}, 
len(modelMapping)) + redirectTargetSet := make(map[string]struct{}, len(modelMapping)) + for source, target := range modelMapping { + redirectSourceSet[source] = struct{}{} + redirectTargetSet[target] = struct{}{} + } + + coveredUpstreamSet := make(map[string]struct{}, len(localSet)+len(redirectTargetSet)) + for modelName := range localSet { + coveredUpstreamSet[modelName] = struct{}{} + } + for modelName := range redirectTargetSet { + coveredUpstreamSet[modelName] = struct{}{} + } + + pendingAdd := lo.Filter(upstreamModels, func(modelName string, _ int) bool { + if _, ok := coveredUpstreamSet[modelName]; ok { + return false + } + if _, ok := ignoredSet[modelName]; ok { + return false + } + return true + }) + pendingRemove := lo.Filter(localModels, func(modelName string, _ int) bool { + // Redirect source models are virtual aliases and should not be removed + // only because they are absent from upstream model list. + if _, ok := redirectSourceSet[modelName]; ok { + return false + } + _, ok := upstreamSet[modelName] + return !ok + }) + return normalizeModelNames(pendingAdd), normalizeModelNames(pendingRemove) +} + +func collectPendingUpstreamModelChanges(channel *model.Channel, settings dto.ChannelOtherSettings) (pendingAddModels []string, pendingRemoveModels []string, err error) { + upstreamModels, err := fetchChannelUpstreamModelIDs(channel) + if err != nil { + return nil, nil, err + } + pendingAddModels, pendingRemoveModels = collectPendingUpstreamModelChangesFromModels( + channel.GetModels(), + upstreamModels, + settings.UpstreamModelUpdateIgnoredModels, + normalizeChannelModelMapping(channel), + ) + return pendingAddModels, pendingRemoveModels, nil +} + +func getUpstreamModelUpdateMinCheckIntervalSeconds() int64 { + interval := int64(common.GetEnvOrDefault( + "CHANNEL_UPSTREAM_MODEL_UPDATE_MIN_CHECK_INTERVAL_SECONDS", + channelUpstreamModelUpdateMinCheckIntervalSeconds, + )) + if interval < 0 { + return channelUpstreamModelUpdateMinCheckIntervalSeconds + } + 
return interval +} + +func fetchChannelUpstreamModelIDs(channel *model.Channel) ([]string, error) { + baseURL := constant.ChannelBaseURLs[channel.Type] + if channel.GetBaseURL() != "" { + baseURL = channel.GetBaseURL() + } + + if channel.Type == constant.ChannelTypeOllama { + key := strings.TrimSpace(strings.Split(channel.Key, "\n")[0]) + models, err := ollama.FetchOllamaModels(baseURL, key) + if err != nil { + return nil, err + } + return normalizeModelNames(lo.Map(models, func(item ollama.OllamaModel, _ int) string { + return item.Name + })), nil + } + + if channel.Type == constant.ChannelTypeGemini { + key, _, apiErr := channel.GetNextEnabledKey() + if apiErr != nil { + return nil, fmt.Errorf("获取渠道密钥失败: %w", apiErr) + } + key = strings.TrimSpace(key) + models, err := gemini.FetchGeminiModels(baseURL, key, channel.GetSetting().Proxy) + if err != nil { + return nil, err + } + return normalizeModelNames(models), nil + } + + var url string + switch channel.Type { + case constant.ChannelTypeAli: + url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL) + case constant.ChannelTypeZhipu_v4: + if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" { + url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL) + } else { + url = fmt.Sprintf("%s/api/paas/v4/models", baseURL) + } + case constant.ChannelTypeVolcEngine: + if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" { + url = fmt.Sprintf("%s/v1/models", plan.OpenAIBaseURL) + } else { + url = fmt.Sprintf("%s/v1/models", baseURL) + } + case constant.ChannelTypeMoonshot: + if plan, ok := constant.ChannelSpecialBases[baseURL]; ok && plan.OpenAIBaseURL != "" { + url = fmt.Sprintf("%s/models", plan.OpenAIBaseURL) + } else { + url = fmt.Sprintf("%s/v1/models", baseURL) + } + default: + url = fmt.Sprintf("%s/v1/models", baseURL) + } + + key, _, apiErr := channel.GetNextEnabledKey() + if apiErr != nil { + return nil, fmt.Errorf("获取渠道密钥失败: %w", apiErr) + } + key = 
strings.TrimSpace(key) + + headers, err := buildFetchModelsHeaders(channel, key) + if err != nil { + return nil, err + } + + body, err := GetResponseBody(http.MethodGet, url, channel, headers) + if err != nil { + return nil, err + } + + var result OpenAIModelsResponse + if err := common.Unmarshal(body, &result); err != nil { + return nil, err + } + + ids := lo.Map(result.Data, func(item OpenAIModel, _ int) string { + if channel.Type == constant.ChannelTypeGemini { + return strings.TrimPrefix(item.ID, "models/") + } + return item.ID + }) + + return normalizeModelNames(ids), nil +} + +func updateChannelUpstreamModelSettings(channel *model.Channel, settings dto.ChannelOtherSettings, updateModels bool) error { + channel.SetOtherSettings(settings) + updates := map[string]interface{}{ + "settings": channel.OtherSettings, + } + if updateModels { + updates["models"] = channel.Models + } + return model.DB.Model(&model.Channel{}).Where("id = ?", channel.Id).Updates(updates).Error +} + +func checkAndPersistChannelUpstreamModelUpdates( + channel *model.Channel, + settings *dto.ChannelOtherSettings, + force bool, + allowAutoApply bool, +) (modelsChanged bool, autoAdded int, err error) { + now := common.GetTimestamp() + if !force { + minInterval := getUpstreamModelUpdateMinCheckIntervalSeconds() + if settings.UpstreamModelUpdateLastCheckTime > 0 && + now-settings.UpstreamModelUpdateLastCheckTime < minInterval { + return false, 0, nil + } + } + + pendingAddModels, pendingRemoveModels, fetchErr := collectPendingUpstreamModelChanges(channel, *settings) + settings.UpstreamModelUpdateLastCheckTime = now + if fetchErr != nil { + if err = updateChannelUpstreamModelSettings(channel, *settings, false); err != nil { + return false, 0, err + } + return false, 0, fetchErr + } + + if allowAutoApply && settings.UpstreamModelUpdateAutoSyncEnabled && len(pendingAddModels) > 0 { + originModels := normalizeModelNames(channel.GetModels()) + mergedModels := mergeModelNames(originModels, 
pendingAddModels) + if len(mergedModels) > len(originModels) { + channel.Models = strings.Join(mergedModels, ",") + autoAdded = len(mergedModels) - len(originModels) + modelsChanged = true + } + settings.UpstreamModelUpdateLastDetectedModels = []string{} + } else { + settings.UpstreamModelUpdateLastDetectedModels = pendingAddModels + } + settings.UpstreamModelUpdateLastRemovedModels = pendingRemoveModels + + if err = updateChannelUpstreamModelSettings(channel, *settings, modelsChanged); err != nil { + return false, autoAdded, err + } + if modelsChanged { + if err = channel.UpdateAbilities(nil); err != nil { + return true, autoAdded, err + } + } + return modelsChanged, autoAdded, nil +} + +func refreshChannelRuntimeCache() { + if common.MemoryCacheEnabled { + func() { + defer func() { + if r := recover(); r != nil { + common.SysLog(fmt.Sprintf("InitChannelCache panic: %v", r)) + } + }() + model.InitChannelCache() + }() + } + service.ResetProxyClientCache() +} + +func shouldSendUpstreamModelUpdateNotification(now int64, changedChannels int, failedChannels int) bool { + if changedChannels <= 0 && failedChannels <= 0 { + return true + } + + channelUpstreamModelUpdateNotifyState.Lock() + defer channelUpstreamModelUpdateNotifyState.Unlock() + + if channelUpstreamModelUpdateNotifyState.lastNotifiedAt > 0 && + now-channelUpstreamModelUpdateNotifyState.lastNotifiedAt < channelUpstreamModelUpdateNotifySuppressWindowSeconds && + channelUpstreamModelUpdateNotifyState.lastChangedChannels == changedChannels && + channelUpstreamModelUpdateNotifyState.lastFailedChannels == failedChannels { + return false + } + + channelUpstreamModelUpdateNotifyState.lastNotifiedAt = now + channelUpstreamModelUpdateNotifyState.lastChangedChannels = changedChannels + channelUpstreamModelUpdateNotifyState.lastFailedChannels = failedChannels + return true +} + +func buildUpstreamModelUpdateTaskNotificationContent( + checkedChannels int, + changedChannels int, + detectedAddModels int, + 
detectedRemoveModels int, + autoAddedModels int, + failedChannelIDs []int, + channelSummaries []upstreamModelUpdateChannelSummary, + addModelSamples []string, + removeModelSamples []string, +) string { + var builder strings.Builder + failedChannels := len(failedChannelIDs) + builder.WriteString(fmt.Sprintf( + "上游模型巡检摘要:检测渠道 %d 个,发现变更 %d 个,新增 %d 个,删除 %d 个,自动同步新增 %d 个,失败 %d 个。", + checkedChannels, + changedChannels, + detectedAddModels, + detectedRemoveModels, + autoAddedModels, + failedChannels, + )) + + if len(channelSummaries) > 0 { + displayCount := min(len(channelSummaries), channelUpstreamModelUpdateNotifyMaxChannelDetails) + builder.WriteString(fmt.Sprintf("\n\n变更渠道明细(展示 %d/%d):", displayCount, len(channelSummaries))) + for _, summary := range channelSummaries[:displayCount] { + builder.WriteString(fmt.Sprintf("\n- %s (+%d / -%d)", summary.ChannelName, summary.AddCount, summary.RemoveCount)) + } + if len(channelSummaries) > displayCount { + builder.WriteString(fmt.Sprintf("\n- 其余 %d 个渠道已省略", len(channelSummaries)-displayCount)) + } + } + + normalizedAddModelSamples := normalizeModelNames(addModelSamples) + if len(normalizedAddModelSamples) > 0 { + displayCount := min(len(normalizedAddModelSamples), channelUpstreamModelUpdateNotifyMaxModelDetails) + builder.WriteString(fmt.Sprintf("\n\n新增模型示例(展示 %d/%d):%s", + displayCount, + len(normalizedAddModelSamples), + strings.Join(normalizedAddModelSamples[:displayCount], ", "), + )) + if len(normalizedAddModelSamples) > displayCount { + builder.WriteString(fmt.Sprintf("(其余 %d 个已省略)", len(normalizedAddModelSamples)-displayCount)) + } + } + + normalizedRemoveModelSamples := normalizeModelNames(removeModelSamples) + if len(normalizedRemoveModelSamples) > 0 { + displayCount := min(len(normalizedRemoveModelSamples), channelUpstreamModelUpdateNotifyMaxModelDetails) + builder.WriteString(fmt.Sprintf("\n\n删除模型示例(展示 %d/%d):%s", + displayCount, + len(normalizedRemoveModelSamples), + 
strings.Join(normalizedRemoveModelSamples[:displayCount], ", "), + )) + if len(normalizedRemoveModelSamples) > displayCount { + builder.WriteString(fmt.Sprintf("(其余 %d 个已省略)", len(normalizedRemoveModelSamples)-displayCount)) + } + } + + if failedChannels > 0 { + displayCount := min(failedChannels, channelUpstreamModelUpdateNotifyMaxFailedChannelIDs) + displayIDs := lo.Map(failedChannelIDs[:displayCount], func(channelID int, _ int) string { + return fmt.Sprintf("%d", channelID) + }) + builder.WriteString(fmt.Sprintf( + "\n\n失败渠道 ID(展示 %d/%d):%s", + displayCount, + failedChannels, + strings.Join(displayIDs, ", "), + )) + if failedChannels > displayCount { + builder.WriteString(fmt.Sprintf("(其余 %d 个已省略)", failedChannels-displayCount)) + } + } + return builder.String() +} + +func runChannelUpstreamModelUpdateTaskOnce() { + if !channelUpstreamModelUpdateTaskRunning.CompareAndSwap(false, true) { + return + } + defer channelUpstreamModelUpdateTaskRunning.Store(false) + + checkedChannels := 0 + failedChannels := 0 + failedChannelIDs := make([]int, 0) + changedChannels := 0 + detectedAddModels := 0 + detectedRemoveModels := 0 + autoAddedModels := 0 + channelSummaries := make([]upstreamModelUpdateChannelSummary, 0) + addModelSamples := make([]string, 0) + removeModelSamples := make([]string, 0) + refreshNeeded := false + + lastID := 0 + for { + var channels []*model.Channel + query := model.DB. + Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override"). + Where("status = ?", common.ChannelStatusEnabled). + Order("id asc"). 
+ Limit(channelUpstreamModelUpdateTaskBatchSize) + if lastID > 0 { + query = query.Where("id > ?", lastID) + } + err := query.Find(&channels).Error + if err != nil { + common.SysLog(fmt.Sprintf("upstream model update task query failed: %v", err)) + break + } + if len(channels) == 0 { + break + } + lastID = channels[len(channels)-1].Id + + for _, channel := range channels { + if channel == nil { + continue + } + + settings := channel.GetOtherSettings() + if !settings.UpstreamModelUpdateCheckEnabled { + continue + } + + checkedChannels++ + modelsChanged, autoAdded, err := checkAndPersistChannelUpstreamModelUpdates(channel, &settings, false, true) + if err != nil { + failedChannels++ + failedChannelIDs = append(failedChannelIDs, channel.Id) + common.SysLog(fmt.Sprintf("upstream model update check failed: channel_id=%d channel_name=%s err=%v", channel.Id, channel.Name, err)) + continue + } + currentAddModels := normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels) + currentRemoveModels := normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels) + currentAddCount := len(currentAddModels) + autoAdded + currentRemoveCount := len(currentRemoveModels) + detectedAddModels += currentAddCount + detectedRemoveModels += currentRemoveCount + if currentAddCount > 0 || currentRemoveCount > 0 { + changedChannels++ + channelSummaries = append(channelSummaries, upstreamModelUpdateChannelSummary{ + ChannelName: channel.Name, + AddCount: currentAddCount, + RemoveCount: currentRemoveCount, + }) + } + addModelSamples = mergeModelNames(addModelSamples, currentAddModels) + removeModelSamples = mergeModelNames(removeModelSamples, currentRemoveModels) + if modelsChanged { + refreshNeeded = true + } + autoAddedModels += autoAdded + + if common.RequestInterval > 0 { + time.Sleep(common.RequestInterval) + } + } + + if len(channels) < channelUpstreamModelUpdateTaskBatchSize { + break + } + } + + if refreshNeeded { + refreshChannelRuntimeCache() + } + + if checkedChannels > 
0 || common.DebugEnabled { + common.SysLog(fmt.Sprintf( + "upstream model update task done: checked_channels=%d changed_channels=%d detected_add_models=%d detected_remove_models=%d failed_channels=%d auto_added_models=%d", + checkedChannels, + changedChannels, + detectedAddModels, + detectedRemoveModels, + failedChannels, + autoAddedModels, + )) + } + if changedChannels > 0 || failedChannels > 0 { + now := common.GetTimestamp() + if !shouldSendUpstreamModelUpdateNotification(now, changedChannels, failedChannels) { + common.SysLog(fmt.Sprintf( + "upstream model update notification skipped in 24h window: changed_channels=%d failed_channels=%d", + changedChannels, + failedChannels, + )) + return + } + service.NotifyUpstreamModelUpdateWatchers( + "上游模型巡检通知", + buildUpstreamModelUpdateTaskNotificationContent( + checkedChannels, + changedChannels, + detectedAddModels, + detectedRemoveModels, + autoAddedModels, + failedChannelIDs, + channelSummaries, + addModelSamples, + removeModelSamples, + ), + ) + } +} + +func StartChannelUpstreamModelUpdateTask() { + channelUpstreamModelUpdateTaskOnce.Do(func() { + if !common.IsMasterNode { + return + } + if !common.GetEnvOrDefaultBool("CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_ENABLED", true) { + common.SysLog("upstream model update task disabled by CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_ENABLED") + return + } + + intervalMinutes := common.GetEnvOrDefault( + "CHANNEL_UPSTREAM_MODEL_UPDATE_TASK_INTERVAL_MINUTES", + channelUpstreamModelUpdateTaskDefaultIntervalMinutes, + ) + if intervalMinutes < 1 { + intervalMinutes = channelUpstreamModelUpdateTaskDefaultIntervalMinutes + } + interval := time.Duration(intervalMinutes) * time.Minute + + go func() { + common.SysLog(fmt.Sprintf("upstream model update task started: interval=%s", interval)) + runChannelUpstreamModelUpdateTaskOnce() + ticker := time.NewTicker(interval) + defer ticker.Stop() + for range ticker.C { + runChannelUpstreamModelUpdateTaskOnce() + } + }() + }) +} + +func 
ApplyChannelUpstreamModelUpdates(c *gin.Context) { + var req applyChannelUpstreamModelUpdatesRequest + if err := c.ShouldBindJSON(&req); err != nil { + common.ApiError(c, err) + return + } + if req.ID <= 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "invalid channel id", + }) + return + } + + channel, err := model.GetChannelById(req.ID, true) + if err != nil { + common.ApiError(c, err) + return + } + beforeSettings := channel.GetOtherSettings() + ignoredModels := intersectModelNames(req.IgnoreModels, beforeSettings.UpstreamModelUpdateLastDetectedModels) + + addedModels, removedModels, remainingModels, remainingRemoveModels, modelsChanged, err := applyChannelUpstreamModelUpdates( + channel, + req.AddModels, + req.IgnoreModels, + req.RemoveModels, + ) + if err != nil { + common.ApiError(c, err) + return + } + + if modelsChanged { + refreshChannelRuntimeCache() + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": gin.H{ + "id": channel.Id, + "added_models": addedModels, + "removed_models": removedModels, + "ignored_models": ignoredModels, + "remaining_models": remainingModels, + "remaining_remove_models": remainingRemoveModels, + "models": channel.Models, + "settings": channel.OtherSettings, + }, + }) +} + +func DetectChannelUpstreamModelUpdates(c *gin.Context) { + var req applyChannelUpstreamModelUpdatesRequest + if err := c.ShouldBindJSON(&req); err != nil { + common.ApiError(c, err) + return + } + if req.ID <= 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "invalid channel id", + }) + return + } + + channel, err := model.GetChannelById(req.ID, true) + if err != nil { + common.ApiError(c, err) + return + } + + settings := channel.GetOtherSettings() + if !settings.UpstreamModelUpdateCheckEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该渠道未开启上游模型更新检测", + }) + return + } + + modelsChanged, autoAdded, err := checkAndPersistChannelUpstreamModelUpdates(channel, &settings, 
true, false) + if err != nil { + common.ApiError(c, err) + return + } + if modelsChanged { + refreshChannelRuntimeCache() + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": detectChannelUpstreamModelUpdatesResult{ + ChannelID: channel.Id, + ChannelName: channel.Name, + AddModels: normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels), + RemoveModels: normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels), + LastCheckTime: settings.UpstreamModelUpdateLastCheckTime, + AutoAddedModels: autoAdded, + }, + }) +} + +func applyChannelUpstreamModelUpdates( + channel *model.Channel, + addModelsInput []string, + ignoreModelsInput []string, + removeModelsInput []string, +) ( + addedModels []string, + removedModels []string, + remainingModels []string, + remainingRemoveModels []string, + modelsChanged bool, + err error, +) { + settings := channel.GetOtherSettings() + pendingAddModels := normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels) + pendingRemoveModels := normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels) + addModels := intersectModelNames(addModelsInput, pendingAddModels) + ignoreModels := intersectModelNames(ignoreModelsInput, pendingAddModels) + removeModels := intersectModelNames(removeModelsInput, pendingRemoveModels) + removeModels = subtractModelNames(removeModels, addModels) + + originModels := normalizeModelNames(channel.GetModels()) + nextModels := applySelectedModelChanges(originModels, addModels, removeModels) + modelsChanged = !slices.Equal(originModels, nextModels) + if modelsChanged { + channel.Models = strings.Join(nextModels, ",") + } + + settings.UpstreamModelUpdateIgnoredModels = mergeModelNames(settings.UpstreamModelUpdateIgnoredModels, ignoreModels) + if len(addModels) > 0 { + settings.UpstreamModelUpdateIgnoredModels = subtractModelNames(settings.UpstreamModelUpdateIgnoredModels, addModels) + } + remainingModels = subtractModelNames(pendingAddModels, 
append(addModels, ignoreModels...)) + remainingRemoveModels = subtractModelNames(pendingRemoveModels, removeModels) + settings.UpstreamModelUpdateLastDetectedModels = remainingModels + settings.UpstreamModelUpdateLastRemovedModels = remainingRemoveModels + settings.UpstreamModelUpdateLastCheckTime = common.GetTimestamp() + + if err := updateChannelUpstreamModelSettings(channel, settings, modelsChanged); err != nil { + return nil, nil, nil, nil, false, err + } + + if modelsChanged { + if err := channel.UpdateAbilities(nil); err != nil { + return addModels, removeModels, remainingModels, remainingRemoveModels, true, err + } + } + return addModels, removeModels, remainingModels, remainingRemoveModels, modelsChanged, nil +} + +func collectPendingApplyUpstreamModelChanges(settings dto.ChannelOtherSettings) (pendingAddModels []string, pendingRemoveModels []string) { + return normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels), normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels) +} + +func findEnabledChannelsAfterID(lastID int, batchSize int) ([]*model.Channel, error) { + var channels []*model.Channel + query := model.DB. + Select("id", "name", "type", "key", "status", "base_url", "models", "settings", "setting", "other", "group", "priority", "weight", "tag", "channel_info", "header_override"). + Where("status = ?", common.ChannelStatusEnabled). + Order("id asc"). 
+ Limit(batchSize) + if lastID > 0 { + query = query.Where("id > ?", lastID) + } + return channels, query.Find(&channels).Error +} + +func ApplyAllChannelUpstreamModelUpdates(c *gin.Context) { + results := make([]applyAllChannelUpstreamModelUpdatesResult, 0) + failed := make([]int, 0) + refreshNeeded := false + addedModelCount := 0 + removedModelCount := 0 + + lastID := 0 + for { + channels, err := findEnabledChannelsAfterID(lastID, channelUpstreamModelUpdateTaskBatchSize) + if err != nil { + common.ApiError(c, err) + return + } + if len(channels) == 0 { + break + } + lastID = channels[len(channels)-1].Id + + for _, channel := range channels { + if channel == nil { + continue + } + + settings := channel.GetOtherSettings() + if !settings.UpstreamModelUpdateCheckEnabled { + continue + } + + pendingAddModels, pendingRemoveModels := collectPendingApplyUpstreamModelChanges(settings) + if len(pendingAddModels) == 0 && len(pendingRemoveModels) == 0 { + continue + } + + addedModels, removedModels, remainingModels, remainingRemoveModels, modelsChanged, err := applyChannelUpstreamModelUpdates( + channel, + pendingAddModels, + nil, + pendingRemoveModels, + ) + if err != nil { + failed = append(failed, channel.Id) + continue + } + if modelsChanged { + refreshNeeded = true + } + addedModelCount += len(addedModels) + removedModelCount += len(removedModels) + results = append(results, applyAllChannelUpstreamModelUpdatesResult{ + ChannelID: channel.Id, + ChannelName: channel.Name, + AddedModels: addedModels, + RemovedModels: removedModels, + RemainingModels: remainingModels, + RemainingRemoveModels: remainingRemoveModels, + }) + } + + if len(channels) < channelUpstreamModelUpdateTaskBatchSize { + break + } + } + + if refreshNeeded { + refreshChannelRuntimeCache() + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": gin.H{ + "processed_channels": len(results), + "added_models": addedModelCount, + "removed_models": removedModelCount, + 
"failed_channel_ids": failed, + "results": results, + }, + }) +} + +func DetectAllChannelUpstreamModelUpdates(c *gin.Context) { + results := make([]detectChannelUpstreamModelUpdatesResult, 0) + failed := make([]int, 0) + detectedAddCount := 0 + detectedRemoveCount := 0 + refreshNeeded := false + + lastID := 0 + for { + channels, err := findEnabledChannelsAfterID(lastID, channelUpstreamModelUpdateTaskBatchSize) + if err != nil { + common.ApiError(c, err) + return + } + if len(channels) == 0 { + break + } + lastID = channels[len(channels)-1].Id + + for _, channel := range channels { + if channel == nil { + continue + } + settings := channel.GetOtherSettings() + if !settings.UpstreamModelUpdateCheckEnabled { + continue + } + + modelsChanged, autoAdded, err := checkAndPersistChannelUpstreamModelUpdates(channel, &settings, true, false) + if err != nil { + failed = append(failed, channel.Id) + continue + } + if modelsChanged { + refreshNeeded = true + } + + addModels := normalizeModelNames(settings.UpstreamModelUpdateLastDetectedModels) + removeModels := normalizeModelNames(settings.UpstreamModelUpdateLastRemovedModels) + detectedAddCount += len(addModels) + detectedRemoveCount += len(removeModels) + results = append(results, detectChannelUpstreamModelUpdatesResult{ + ChannelID: channel.Id, + ChannelName: channel.Name, + AddModels: addModels, + RemoveModels: removeModels, + LastCheckTime: settings.UpstreamModelUpdateLastCheckTime, + AutoAddedModels: autoAdded, + }) + } + + if len(channels) < channelUpstreamModelUpdateTaskBatchSize { + break + } + } + + if refreshNeeded { + refreshChannelRuntimeCache() + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": gin.H{ + "processed_channels": len(results), + "failed_channel_ids": failed, + "detected_add_models": detectedAddCount, + "detected_remove_models": detectedRemoveCount, + "channel_detected_results": results, + }, + }) +} diff --git a/controller/channel_upstream_update_test.go 
b/controller/channel_upstream_update_test.go new file mode 100644 index 000000000..153119d41 --- /dev/null +++ b/controller/channel_upstream_update_test.go @@ -0,0 +1,167 @@ +package controller + +import ( + "testing" + + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/model" + "github.com/stretchr/testify/require" +) + +func TestNormalizeModelNames(t *testing.T) { + result := normalizeModelNames([]string{ + " gpt-4o ", + "", + "gpt-4o", + "gpt-4.1", + " ", + }) + + require.Equal(t, []string{"gpt-4o", "gpt-4.1"}, result) +} + +func TestMergeModelNames(t *testing.T) { + result := mergeModelNames( + []string{"gpt-4o", "gpt-4.1"}, + []string{"gpt-4.1", " gpt-4.1-mini ", "gpt-4o"}, + ) + + require.Equal(t, []string{"gpt-4o", "gpt-4.1", "gpt-4.1-mini"}, result) +} + +func TestSubtractModelNames(t *testing.T) { + result := subtractModelNames( + []string{"gpt-4o", "gpt-4.1", "gpt-4.1-mini"}, + []string{"gpt-4.1", "not-exists"}, + ) + + require.Equal(t, []string{"gpt-4o", "gpt-4.1-mini"}, result) +} + +func TestIntersectModelNames(t *testing.T) { + result := intersectModelNames( + []string{"gpt-4o", "gpt-4.1", "gpt-4.1", "not-exists"}, + []string{"gpt-4.1", "gpt-4o-mini", "gpt-4o"}, + ) + + require.Equal(t, []string{"gpt-4o", "gpt-4.1"}, result) +} + +func TestApplySelectedModelChanges(t *testing.T) { + t.Run("add and remove together", func(t *testing.T) { + result := applySelectedModelChanges( + []string{"gpt-4o", "gpt-4.1", "claude-3"}, + []string{"gpt-4.1-mini"}, + []string{"claude-3"}, + ) + + require.Equal(t, []string{"gpt-4o", "gpt-4.1", "gpt-4.1-mini"}, result) + }) + + t.Run("add wins when conflict with remove", func(t *testing.T) { + result := applySelectedModelChanges( + []string{"gpt-4o"}, + []string{"gpt-4.1"}, + []string{"gpt-4.1"}, + ) + + require.Equal(t, []string{"gpt-4o", "gpt-4.1"}, result) + }) +} + +func TestCollectPendingApplyUpstreamModelChanges(t *testing.T) { + settings := dto.ChannelOtherSettings{ + 
UpstreamModelUpdateLastDetectedModels: []string{" gpt-4o ", "gpt-4o", "gpt-4.1"}, + UpstreamModelUpdateLastRemovedModels: []string{" old-model ", "", "old-model"}, + } + + pendingAddModels, pendingRemoveModels := collectPendingApplyUpstreamModelChanges(settings) + + require.Equal(t, []string{"gpt-4o", "gpt-4.1"}, pendingAddModels) + require.Equal(t, []string{"old-model"}, pendingRemoveModels) +} + +func TestNormalizeChannelModelMapping(t *testing.T) { + modelMapping := `{ + " alias-model ": " upstream-model ", + "": "invalid", + "invalid-target": "" + }` + channel := &model.Channel{ + ModelMapping: &modelMapping, + } + + result := normalizeChannelModelMapping(channel) + require.Equal(t, map[string]string{ + "alias-model": "upstream-model", + }, result) +} + +func TestCollectPendingUpstreamModelChangesFromModels_WithModelMapping(t *testing.T) { + pendingAddModels, pendingRemoveModels := collectPendingUpstreamModelChangesFromModels( + []string{"alias-model", "gpt-4o", "stale-model"}, + []string{"gpt-4o", "gpt-4.1", "mapped-target"}, + []string{"gpt-4.1"}, + map[string]string{ + "alias-model": "mapped-target", + }, + ) + + require.Equal(t, []string{}, pendingAddModels) + require.Equal(t, []string{"stale-model"}, pendingRemoveModels) +} + +func TestBuildUpstreamModelUpdateTaskNotificationContent_OmitOverflowDetails(t *testing.T) { + channelSummaries := make([]upstreamModelUpdateChannelSummary, 0, 12) + for i := 0; i < 12; i++ { + channelSummaries = append(channelSummaries, upstreamModelUpdateChannelSummary{ + ChannelName: "channel-" + string(rune('A'+i)), + AddCount: i + 1, + RemoveCount: i, + }) + } + + content := buildUpstreamModelUpdateTaskNotificationContent( + 24, + 12, + 56, + 21, + 9, + []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + channelSummaries, + []string{ + "gpt-4.1", "gpt-4.1-mini", "o3", "o4-mini", "gemini-2.5-pro", "claude-3.7-sonnet", + "qwen-max", "deepseek-r1", "llama-3.3-70b", "mistral-large", "command-r-plus", "doubao-pro-32k", + "hunyuan-large", 
+ }, + []string{ + "gpt-3.5-turbo", "claude-2.1", "gemini-1.5-pro", "mixtral-8x7b", "qwen-plus", "glm-4", + "yi-large", "moonshot-v1", "doubao-lite", + }, + ) + + require.Contains(t, content, "其余 4 个渠道已省略") + require.Contains(t, content, "其余 1 个已省略") + require.Contains(t, content, "失败渠道 ID(展示 10/12)") + require.Contains(t, content, "其余 2 个已省略") +} + +func TestShouldSendUpstreamModelUpdateNotification(t *testing.T) { + channelUpstreamModelUpdateNotifyState.Lock() + channelUpstreamModelUpdateNotifyState.lastNotifiedAt = 0 + channelUpstreamModelUpdateNotifyState.lastChangedChannels = 0 + channelUpstreamModelUpdateNotifyState.lastFailedChannels = 0 + channelUpstreamModelUpdateNotifyState.Unlock() + + baseTime := int64(2000000) + + require.True(t, shouldSendUpstreamModelUpdateNotification(baseTime, 6, 0)) + require.False(t, shouldSendUpstreamModelUpdateNotification(baseTime+3600, 6, 0)) + require.True(t, shouldSendUpstreamModelUpdateNotification(baseTime+3600, 7, 0)) + require.False(t, shouldSendUpstreamModelUpdateNotification(baseTime+7200, 7, 0)) + require.True(t, shouldSendUpstreamModelUpdateNotification(baseTime+8000, 0, 3)) + require.False(t, shouldSendUpstreamModelUpdateNotification(baseTime+9000, 0, 3)) + require.True(t, shouldSendUpstreamModelUpdateNotification(baseTime+10000, 0, 4)) + require.True(t, shouldSendUpstreamModelUpdateNotification(baseTime+90000, 7, 0)) + require.True(t, shouldSendUpstreamModelUpdateNotification(baseTime+90001, 0, 0)) +} diff --git a/controller/user.go b/controller/user.go index b58eab88f..4ec64e29e 100644 --- a/controller/user.go +++ b/controller/user.go @@ -1032,17 +1032,18 @@ func TopUp(c *gin.Context) { } type UpdateUserSettingRequest struct { - QuotaWarningType string `json:"notify_type"` - QuotaWarningThreshold float64 `json:"quota_warning_threshold"` - WebhookUrl string `json:"webhook_url,omitempty"` - WebhookSecret string `json:"webhook_secret,omitempty"` - NotificationEmail string `json:"notification_email,omitempty"` - 
BarkUrl string `json:"bark_url,omitempty"` - GotifyUrl string `json:"gotify_url,omitempty"` - GotifyToken string `json:"gotify_token,omitempty"` - GotifyPriority int `json:"gotify_priority,omitempty"` - AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"` - RecordIpLog bool `json:"record_ip_log"` + QuotaWarningType string `json:"notify_type"` + QuotaWarningThreshold float64 `json:"quota_warning_threshold"` + WebhookUrl string `json:"webhook_url,omitempty"` + WebhookSecret string `json:"webhook_secret,omitempty"` + NotificationEmail string `json:"notification_email,omitempty"` + BarkUrl string `json:"bark_url,omitempty"` + GotifyUrl string `json:"gotify_url,omitempty"` + GotifyToken string `json:"gotify_token,omitempty"` + GotifyPriority int `json:"gotify_priority,omitempty"` + UpstreamModelUpdateNotifyEnabled *bool `json:"upstream_model_update_notify_enabled,omitempty"` + AcceptUnsetModelRatioModel bool `json:"accept_unset_model_ratio_model"` + RecordIpLog bool `json:"record_ip_log"` } func UpdateUserSetting(c *gin.Context) { @@ -1132,13 +1133,19 @@ func UpdateUserSetting(c *gin.Context) { common.ApiError(c, err) return } + existingSettings := user.GetSetting() + upstreamModelUpdateNotifyEnabled := existingSettings.UpstreamModelUpdateNotifyEnabled + if user.Role >= common.RoleAdminUser && req.UpstreamModelUpdateNotifyEnabled != nil { + upstreamModelUpdateNotifyEnabled = *req.UpstreamModelUpdateNotifyEnabled + } // 构建设置 settings := dto.UserSetting{ - NotifyType: req.QuotaWarningType, - QuotaWarningThreshold: req.QuotaWarningThreshold, - AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel, - RecordIpLog: req.RecordIpLog, + NotifyType: req.QuotaWarningType, + QuotaWarningThreshold: req.QuotaWarningThreshold, + UpstreamModelUpdateNotifyEnabled: upstreamModelUpdateNotifyEnabled, + AcceptUnsetRatioModel: req.AcceptUnsetModelRatioModel, + RecordIpLog: req.RecordIpLog, } // 如果是webhook类型,添加webhook相关设置 diff --git a/dto/channel_settings.go 
b/dto/channel_settings.go index 72fdf460c..fc04937e7 100644 --- a/dto/channel_settings.go +++ b/dto/channel_settings.go @@ -24,16 +24,22 @@ const ( ) type ChannelOtherSettings struct { - AzureResponsesVersion string `json:"azure_responses_version,omitempty"` - VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key" - OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"` - ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true - AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费) - AllowInferenceGeo bool `json:"allow_inference_geo,omitempty"` // 是否允许 inference_geo 透传(仅 Claude,默认过滤以满足数据驻留合规) - DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用) - AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私) - AllowIncludeObfuscation bool `json:"allow_include_obfuscation,omitempty"` // 是否允许 stream_options.include_obfuscation 透传(默认过滤以避免关闭流混淆保护) - AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"` + AzureResponsesVersion string `json:"azure_responses_version,omitempty"` + VertexKeyType VertexKeyType `json:"vertex_key_type,omitempty"` // "json" or "api_key" + OpenRouterEnterprise *bool `json:"openrouter_enterprise,omitempty"` + ClaudeBetaQuery bool `json:"claude_beta_query,omitempty"` // Claude 渠道是否强制追加 ?beta=true + AllowServiceTier bool `json:"allow_service_tier,omitempty"` // 是否允许 service_tier 透传(默认过滤以避免额外计费) + AllowInferenceGeo bool `json:"allow_inference_geo,omitempty"` // 是否允许 inference_geo 透传(仅 Claude,默认过滤以满足数据驻留合规) + AllowSafetyIdentifier bool `json:"allow_safety_identifier,omitempty"` // 是否允许 safety_identifier 透传(默认过滤以保护用户隐私) + DisableStore bool `json:"disable_store,omitempty"` // 是否禁用 store 透传(默认允许透传,禁用后可能导致 Codex 无法使用) + AllowIncludeObfuscation bool `json:"allow_include_obfuscation,omitempty"` // 是否允许 stream_options.include_obfuscation 
透传(默认过滤以避免关闭流混淆保护) + AwsKeyType AwsKeyType `json:"aws_key_type,omitempty"` + UpstreamModelUpdateCheckEnabled bool `json:"upstream_model_update_check_enabled,omitempty"` // 是否检测上游模型更新 + UpstreamModelUpdateAutoSyncEnabled bool `json:"upstream_model_update_auto_sync_enabled,omitempty"` // 是否自动同步上游模型更新 + UpstreamModelUpdateLastCheckTime int64 `json:"upstream_model_update_last_check_time,omitempty"` // 上次检测时间 + UpstreamModelUpdateLastDetectedModels []string `json:"upstream_model_update_last_detected_models,omitempty"` // 上次检测到的可加入模型 + UpstreamModelUpdateLastRemovedModels []string `json:"upstream_model_update_last_removed_models,omitempty"` // 上次检测到的可删除模型 + UpstreamModelUpdateIgnoredModels []string `json:"upstream_model_update_ignored_models,omitempty"` // 手动忽略的模型 } func (s *ChannelOtherSettings) IsOpenRouterEnterprise() bool { diff --git a/dto/user_settings.go b/dto/user_settings.go index 48411c86d..dbf555fad 100644 --- a/dto/user_settings.go +++ b/dto/user_settings.go @@ -1,20 +1,21 @@ package dto type UserSetting struct { - NotifyType string `json:"notify_type,omitempty"` // QuotaWarningType 额度预警类型 - QuotaWarningThreshold float64 `json:"quota_warning_threshold,omitempty"` // QuotaWarningThreshold 额度预警阈值 - WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址 - WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥 - NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址 - BarkUrl string `json:"bark_url,omitempty"` // BarkUrl Bark推送URL - GotifyUrl string `json:"gotify_url,omitempty"` // GotifyUrl Gotify服务器地址 - GotifyToken string `json:"gotify_token,omitempty"` // GotifyToken Gotify应用令牌 - GotifyPriority int `json:"gotify_priority"` // GotifyPriority Gotify消息优先级 - AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型 - RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP - SidebarModules string 
`json:"sidebar_modules,omitempty"` // SidebarModules 左侧边栏模块配置 - BillingPreference string `json:"billing_preference,omitempty"` // BillingPreference 扣费策略(订阅/钱包) - Language string `json:"language,omitempty"` // Language 用户语言偏好 (zh, en) + NotifyType string `json:"notify_type,omitempty"` // QuotaWarningType 额度预警类型 + QuotaWarningThreshold float64 `json:"quota_warning_threshold,omitempty"` // QuotaWarningThreshold 额度预警阈值 + WebhookUrl string `json:"webhook_url,omitempty"` // WebhookUrl webhook地址 + WebhookSecret string `json:"webhook_secret,omitempty"` // WebhookSecret webhook密钥 + NotificationEmail string `json:"notification_email,omitempty"` // NotificationEmail 通知邮箱地址 + BarkUrl string `json:"bark_url,omitempty"` // BarkUrl Bark推送URL + GotifyUrl string `json:"gotify_url,omitempty"` // GotifyUrl Gotify服务器地址 + GotifyToken string `json:"gotify_token,omitempty"` // GotifyToken Gotify应用令牌 + GotifyPriority int `json:"gotify_priority"` // GotifyPriority Gotify消息优先级 + UpstreamModelUpdateNotifyEnabled bool `json:"upstream_model_update_notify_enabled,omitempty"` // 是否接收上游模型更新定时检测通知(仅管理员) + AcceptUnsetRatioModel bool `json:"accept_unset_model_ratio_model,omitempty"` // AcceptUnsetRatioModel 是否接受未设置价格的模型 + RecordIpLog bool `json:"record_ip_log,omitempty"` // 是否记录请求和错误日志IP + SidebarModules string `json:"sidebar_modules,omitempty"` // SidebarModules 左侧边栏模块配置 + BillingPreference string `json:"billing_preference,omitempty"` // BillingPreference 扣费策略(订阅/钱包) + Language string `json:"language,omitempty"` // Language 用户语言偏好 (zh, en) } var ( diff --git a/main.go b/main.go index 476a2ed24..dbbf44a18 100644 --- a/main.go +++ b/main.go @@ -121,6 +121,9 @@ func main() { return a } + // Channel upstream model update check task + controller.StartChannelUpstreamModelUpdateTask() + if common.IsMasterNode && constant.UpdateTask { gopool.Go(func() { controller.UpdateMidjourneyTaskBulk() diff --git a/router/api-router.go b/router/api-router.go index d48934000..fafb99575 100644 --- a/router/api-router.go 
+++ b/router/api-router.go @@ -237,6 +237,10 @@ func SetApiRouter(router *gin.Engine) { channelRoute.GET("/tag/models", controller.GetTagModels) channelRoute.POST("/copy/:id", controller.CopyChannel) channelRoute.POST("/multi_key/manage", controller.ManageMultiKeys) + channelRoute.POST("/upstream_updates/apply", controller.ApplyChannelUpstreamModelUpdates) + channelRoute.POST("/upstream_updates/apply_all", controller.ApplyAllChannelUpstreamModelUpdates) + channelRoute.POST("/upstream_updates/detect", controller.DetectChannelUpstreamModelUpdates) + channelRoute.POST("/upstream_updates/detect_all", controller.DetectAllChannelUpstreamModelUpdates) } tokenRoute := apiRouter.Group("/token") tokenRoute.Use(middleware.UserAuth()) diff --git a/service/task_billing_test.go b/service/task_billing_test.go index 1145bba54..79c8c49eb 100644 --- a/service/task_billing_test.go +++ b/service/task_billing_test.go @@ -125,8 +125,8 @@ func makeTask(userId, channelId, quota, tokenId int, billingSource string, subsc SubscriptionId: subscriptionId, TokenId: tokenId, BillingContext: &model.TaskBillingContext{ - ModelPrice: 0.02, - GroupRatio: 1.0, + ModelPrice: 0.02, + GroupRatio: 1.0, OriginModelName: "test-model", }, }, @@ -615,9 +615,11 @@ type mockAdaptor struct { adjustReturn int } -func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {} -func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { return nil, nil } -func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil } +func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {} +func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { + return nil, nil +} +func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil } func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int { return m.adjustReturn } diff --git a/service/user_notify.go 
b/service/user_notify.go index cecf46cad..27a72b8be 100644 --- a/service/user_notify.go +++ b/service/user_notify.go @@ -22,6 +22,32 @@ func NotifyRootUser(t string, subject string, content string) { } } +func NotifyUpstreamModelUpdateWatchers(subject string, content string) { + var users []model.User + if err := model.DB. + Select("id", "email", "role", "status", "setting"). + Where("status = ? AND role >= ?", common.UserStatusEnabled, common.RoleAdminUser). + Find(&users).Error; err != nil { + common.SysLog(fmt.Sprintf("failed to query upstream update notification users: %s", err.Error())) + return + } + + notification := dto.NewNotify(dto.NotifyTypeChannelUpdate, subject, content, nil) + sentCount := 0 + for _, user := range users { + userSetting := user.GetSetting() + if !userSetting.UpstreamModelUpdateNotifyEnabled { + continue + } + if err := NotifyUser(user.Id, user.Email, userSetting, notification); err != nil { + common.SysLog(fmt.Sprintf("failed to notify user %d for upstream model update: %s", user.Id, err.Error())) + continue + } + sentCount++ + } + common.SysLog(fmt.Sprintf("upstream model update notifications sent: %d", sentCount)) +} + func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data dto.Notify) error { notifyType := userSetting.NotifyType if notifyType == "" { diff --git a/web/src/components/settings/PersonalSetting.jsx b/web/src/components/settings/PersonalSetting.jsx index 8ee6415ac..aecafd44d 100644 --- a/web/src/components/settings/PersonalSetting.jsx +++ b/web/src/components/settings/PersonalSetting.jsx @@ -86,6 +86,7 @@ const PersonalSetting = () => { gotifyUrl: '', gotifyToken: '', gotifyPriority: 5, + upstreamModelUpdateNotifyEnabled: false, acceptUnsetModelRatioModel: false, recordIpLog: false, }); @@ -158,6 +159,8 @@ const PersonalSetting = () => { gotifyToken: settings.gotify_token || '', gotifyPriority: settings.gotify_priority !== undefined ? 
settings.gotify_priority : 5, + upstreamModelUpdateNotifyEnabled: + settings.upstream_model_update_notify_enabled === true, acceptUnsetModelRatioModel: settings.accept_unset_model_ratio_model || false, recordIpLog: settings.record_ip_log || false, @@ -426,6 +429,8 @@ const PersonalSetting = () => { const parsed = parseInt(notificationSettings.gotifyPriority); return isNaN(parsed) ? 5 : parsed; })(), + upstream_model_update_notify_enabled: + notificationSettings.upstreamModelUpdateNotifyEnabled === true, accept_unset_model_ratio_model: notificationSettings.acceptUnsetModelRatioModel, record_ip_log: notificationSettings.recordIpLog, diff --git a/web/src/components/settings/personal/cards/NotificationSettings.jsx b/web/src/components/settings/personal/cards/NotificationSettings.jsx index e57e39d63..5e8d4fd82 100644 --- a/web/src/components/settings/personal/cards/NotificationSettings.jsx +++ b/web/src/components/settings/personal/cards/NotificationSettings.jsx @@ -58,6 +58,7 @@ const NotificationSettings = ({ const formApiRef = useRef(null); const [statusState] = useContext(StatusContext); const [userState] = useContext(UserContext); + const isAdminOrRoot = (userState?.user?.role || 0) >= 10; // 左侧边栏设置相关状态 const [sidebarLoading, setSidebarLoading] = useState(false); @@ -470,6 +471,21 @@ const NotificationSettings = ({ ]} /> + {isAdminOrRoot && ( +