diff --git a/controller/model_sync.go b/controller/model_sync.go new file mode 100644 index 000000000..5e2803c5d --- /dev/null +++ b/controller/model_sync.go @@ -0,0 +1,463 @@ +package controller + +import ( + "context" + "encoding/json" + "errors" + "io" + "net" + "net/http" + "strings" + "time" + + "one-api/model" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +// 上游地址 +const ( + upstreamModelsURL = "https://basellm.github.io/llm-metadata/api/newapi/models.json" + upstreamVendorsURL = "https://basellm.github.io/llm-metadata/api/newapi/vendors.json" +) + +type upstreamEnvelope[T any] struct { + Success bool `json:"success"` + Message string `json:"message"` + Data []T `json:"data"` +} + +type upstreamModel struct { + Description string `json:"description"` + Endpoints json.RawMessage `json:"endpoints"` + Icon string `json:"icon"` + ModelName string `json:"model_name"` + NameRule int `json:"name_rule"` + Status int `json:"status"` + Tags string `json:"tags"` + VendorName string `json:"vendor_name"` +} + +type upstreamVendor struct { + Description string `json:"description"` + Icon string `json:"icon"` + Name string `json:"name"` + Status int `json:"status"` +} + +type overwriteField struct { + ModelName string `json:"model_name"` + Fields []string `json:"fields"` +} + +type syncRequest struct { + Overwrite []overwriteField `json:"overwrite"` +} + +func newHTTPClient() *http.Client { + dialer := &net.Dialer{Timeout: 10 * time.Second} + transport := &http.Transport{ + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + } + transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + if strings.HasSuffix(host, "github.io") { + if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil { + return conn, nil + } + return dialer.DialContext(ctx, "tcp6", addr) + } + return dialer.DialContext(ctx, network, addr) + } + return &http.Client{Transport: transport} +} + +var httpClient = newHTTPClient() + +func fetchJSON[T any](ctx context.Context, url string, out *upstreamEnvelope[T]) error { + var lastErr error + for attempt := 0; attempt < 3; attempt++ { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + resp, err := httpClient.Do(req) + if err != nil { + lastErr = err + time.Sleep(time.Duration(200*(1< id + vendorIDCache := make(map[string]int) + + for _, name := range missing { + up, ok := modelByName[name] + if !ok { + skipped = append(skipped, name) + continue + } + + // 若本地已存在且设置为不同步,则跳过(极端情况:缺失列表与本地状态不同步时) + var existing model.Model + if err := model.DB.Where("model_name = ?", name).First(&existing).Error; err == nil { + if existing.SyncOfficial == 0 { + skipped = append(skipped, name) + continue + } + } + + // 确保 vendor 存在 + vendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors) + + // 创建模型 + mi := &model.Model{ + ModelName: name, + Description: up.Description, + Icon: up.Icon, + Tags: up.Tags, + VendorID: vendorID, + Status: chooseStatus(up.Status, 1), + NameRule: up.NameRule, + } + if err := mi.Insert(); err == nil { + createdModels++ + createdList = append(createdList, name) + } else { + skipped = append(skipped, name) + } + } + + // 4) 处理可选覆盖(更新本地已有模型的差异字段) + if len(req.Overwrite) > 0 { + // vendorIDCache 已用于创建阶段,可复用 + for _, ow := range req.Overwrite { + up, ok := modelByName[ow.ModelName] + if !ok { + continue + } + var local model.Model + if err := model.DB.Where("model_name = ?", ow.ModelName).First(&local).Error; err != nil { + continue + } + + // 跳过被禁用官方同步的模型 + if local.SyncOfficial == 0 { + continue + } + + // 映射 vendor + newVendorID := ensureVendorID(up.VendorName, vendorByName, vendorIDCache, &createdVendors) + + // 应用字段覆盖(事务) + _ = model.DB.Transaction(func(tx *gorm.DB) error { + needUpdate := false + if containsField(ow.Fields, "description") { + local.Description = up.Description + needUpdate = true + } + if containsField(ow.Fields, "icon") { + local.Icon = up.Icon + needUpdate = true + } + if containsField(ow.Fields, "tags") { + local.Tags = up.Tags + needUpdate = true + } + if containsField(ow.Fields, "vendor") { + local.VendorID = newVendorID + needUpdate = true + } + if containsField(ow.Fields, "name_rule") { + local.NameRule = up.NameRule + needUpdate = true + } + if containsField(ow.Fields, "status") { + local.Status = chooseStatus(up.Status, local.Status) + needUpdate = true + } + if !needUpdate { + return nil + } + if err := tx.Save(&local).Error; err != nil { + return err + } + updatedModels++ + updatedList = append(updatedList, ow.ModelName) + return nil + }) + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "created_models": createdModels, + "created_vendors": createdVendors, + "updated_models": updatedModels, + "skipped_models": skipped, + "created_list": createdList, + "updated_list": updatedList, + }, + }) +} + +func containsField(fields []string, key string) bool { + key = strings.ToLower(strings.TrimSpace(key)) + for _, f := range fields { + if strings.ToLower(strings.TrimSpace(f)) == key { + return true + } + } + return false +} + +func coalesce(a, b string) string { + if strings.TrimSpace(a) != "" { + return a + } + return b +} + +func chooseStatus(primary, fallback int) int { + if primary == 0 && fallback != 0 { + return fallback + } + if primary != 0 { + return primary + } + return 1 +} + +// SyncUpstreamPreview 预览上游与本地的差异(仅用于弹窗选择) +func SyncUpstreamPreview(c *gin.Context) { + // 1) 拉取上游数据 + ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second) + defer cancel() + + var vendorsEnv upstreamEnvelope[upstreamVendor] + _ = fetchJSON(ctx, upstreamVendorsURL, &vendorsEnv) + + var modelsEnv upstreamEnvelope[upstreamModel] + if err := fetchJSON(ctx, upstreamModelsURL, &modelsEnv); err != nil { + c.JSON(http.StatusOK, gin.H{"success": false, "message": "获取上游模型失败: " + err.Error()}) + return + } + + vendorByName := make(map[string]upstreamVendor) + for _, v := range vendorsEnv.Data { + if v.Name != "" { + vendorByName[v.Name] = v + } + } + modelByName := make(map[string]upstreamModel) + upstreamNames := make([]string, 0, len(modelsEnv.Data)) + for _, m := range modelsEnv.Data { + if m.ModelName != "" { + modelByName[m.ModelName] = m + upstreamNames = append(upstreamNames, m.ModelName) + } + } + + // 2) 本地已有模型 + var locals []model.Model + if len(upstreamNames) > 0 { + _ = model.DB.Where("model_name IN ? AND sync_official <> 0", upstreamNames).Find(&locals).Error + } + + // 本地 vendor 名称映射 + vendorIdSet := make(map[int]struct{}) + for _, m := range locals { + if m.VendorID != 0 { + vendorIdSet[m.VendorID] = struct{}{} + } + } + vendorIDs := make([]int, 0, len(vendorIdSet)) + for id := range vendorIdSet { + vendorIDs = append(vendorIDs, id) + } + idToVendorName := make(map[int]string) + if len(vendorIDs) > 0 { + var dbVendors []model.Vendor + _ = model.DB.Where("id IN ?", vendorIDs).Find(&dbVendors).Error + for _, v := range dbVendors { + idToVendorName[v.Id] = v.Name + } + } + + // 3) 缺失且上游存在的模型 + missingList, _ := model.GetMissingModels() + var missing []string + for _, name := range missingList { + if _, ok := modelByName[name]; ok { + missing = append(missing, name) + } + } + + // 4) 计算冲突字段 + type conflictField struct { + Field string `json:"field"` + Local interface{} `json:"local"` + Upstream interface{} `json:"upstream"` + } + type conflictItem struct { + ModelName string `json:"model_name"` + Fields []conflictField `json:"fields"` + } + + var conflicts []conflictItem + for _, local := range locals { + up, ok := modelByName[local.ModelName] + if !ok { + continue + } + fields := make([]conflictField, 0, 6) + if strings.TrimSpace(local.Description) != strings.TrimSpace(up.Description) { + fields = append(fields, conflictField{Field: "description", Local: local.Description, Upstream: up.Description}) + } + if strings.TrimSpace(local.Icon) != strings.TrimSpace(up.Icon) { + fields = append(fields, conflictField{Field: "icon", Local: local.Icon, Upstream: up.Icon}) + } + if strings.TrimSpace(local.Tags) != strings.TrimSpace(up.Tags) { + fields = append(fields, conflictField{Field: "tags", Local: local.Tags, Upstream: up.Tags}) + } + // vendor 对比使用名称 + localVendor := idToVendorName[local.VendorID] + if strings.TrimSpace(localVendor) != strings.TrimSpace(up.VendorName) { + fields = append(fields, conflictField{Field: "vendor", Local: localVendor, Upstream: up.VendorName}) + } + if local.NameRule != up.NameRule { + fields = append(fields, conflictField{Field: "name_rule", Local: local.NameRule, Upstream: up.NameRule}) + } + if local.Status != chooseStatus(up.Status, local.Status) { + fields = append(fields, conflictField{Field: "status", Local: local.Status, Upstream: up.Status}) + } + if len(fields) > 0 { + conflicts = append(conflicts, conflictItem{ModelName: local.ModelName, Fields: fields}) + } + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "data": gin.H{ + "missing": missing, + "conflicts": conflicts, + }, + }) +} + + diff --git a/model/model_meta.go b/model/model_meta.go index e9582e441..a6230553b 100644 --- a/model/model_meta.go +++ b/model/model_meta.go @@ -28,6 +28,7 @@ type Model struct { VendorID int `json:"vendor_id,omitempty" gorm:"index"` Endpoints string `json:"endpoints,omitempty" gorm:"type:text"` Status int `json:"status" gorm:"default:1"` + SyncOfficial int `json:"sync_official" gorm:"default:1"` CreatedTime int64 `json:"created_time" gorm:"bigint"` UpdatedTime int64 `json:"updated_time" gorm:"bigint"` DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name_delete_at,priority:2"` diff --git a/router/api-router.go b/router/api-router.go index 311bb0a4b..773857385 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -224,6 +224,8 @@ func SetApiRouter(router *gin.Engine) { modelsRoute := apiRouter.Group("/models") modelsRoute.Use(middleware.AdminAuth()) { + modelsRoute.GET("/sync_upstream/preview", controller.SyncUpstreamPreview) + modelsRoute.POST("/sync_upstream", controller.SyncUpstreamModels) modelsRoute.GET("/missing", controller.GetMissingModels) modelsRoute.GET("/", controller.GetAllModelsMeta) modelsRoute.GET("/search", controller.SearchModelsMeta) diff --git a/web/src/components/table/models/ModelsActions.jsx b/web/src/components/table/models/ModelsActions.jsx index cc6b8ef8e..cc6c9afed 100644 --- a/web/src/components/table/models/ModelsActions.jsx +++ b/web/src/components/table/models/ModelsActions.jsx @@ -21,10 +21,11 @@ import React, { useState } from 'react'; import MissingModelsModal from './modals/MissingModelsModal'; import PrefillGroupManagement from './modals/PrefillGroupManagement'; import EditPrefillGroupModal from './modals/EditPrefillGroupModal'; -import { Button, Modal } from '@douyinfe/semi-ui'; +import { Button, Modal, Popover } from '@douyinfe/semi-ui'; import { showSuccess, showError, copy } from '../../../helpers'; import CompactModeToggle from '../../common/ui/CompactModeToggle'; import SelectionNotification from './components/SelectionNotification'; +import UpstreamConflictModal from './modals/UpstreamConflictModal'; const ModelsActions = ({ selectedKeys, @@ -32,6 +33,11 @@ const ModelsActions = ({ setEditingModel, setShowEdit, batchDeleteModels, + syncing, + previewing, + syncUpstream, + previewUpstreamDiff, + applyUpstreamOverwrite, compactMode, setCompactMode, t, @@ -42,6 +48,21 @@ const ModelsActions = ({ const [showGroupManagement, setShowGroupManagement] = useState(false); const [showAddPrefill, setShowAddPrefill] = useState(false); const [prefillInit, setPrefillInit] = useState({ id: undefined }); + const [showConflict, setShowConflict] = useState(false); + const [conflicts, setConflicts] = useState([]); + + const handleSyncUpstream = async () => { + // 先预览 + const data = await previewUpstreamDiff?.(); + const conflictItems = data?.conflicts || []; + if (conflictItems.length > 0) { + setConflicts(conflictItems); + setShowConflict(true); + return; + } + // 无冲突,直接同步缺失 + await syncUpstream?.(); + }; // Handle delete selected models with confirmation const handleDeleteSelectedModels = () => { @@ -104,6 +125,38 @@ const ModelsActions = ({ {t('未配置模型')} + +
+ {t( + '模型社区需要大家的共同维护,如发现数据有误或想贡献新的模型数据,请访问:', + )} +
+ + https://github.com/basellm/llm-metadata + + + } + > + +
+