Files
new-api/model/model_meta.go
wans10 3229b81149 fix(model): 解决模型创建和更新时零值字段被默认值覆盖的问题
- 在创建记录前保存原始状态和同步官方字段值
- 使用独立的更新操作确保零值能够正确保存到数据库
- 修改更新方法使用 Select 强制更新所有字段包括零值
- 避免 GORM 默认行为对零值字段应用默认值导致数据丢失
2026-02-03 13:32:14 +08:00

161 lines
4.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package model
import (
"strconv"
"github.com/QuantumNous/new-api/common"
"gorm.io/gorm"
)
// Name-rule identifiers. The values are spelled out explicitly (0 through 3)
// so that reordering or inserting a constant cannot silently renumber them.
// Presumably these are the values stored in Model.NameRule — confirm against
// the callers.
const (
	NameRuleExact    = 0
	NameRulePrefix   = 1
	NameRuleContains = 2
	NameRuleSuffix   = 3
)
// BoundChannel is the name/type pair describing one channel. It is the element
// type of Model.BoundChannels and of the map values returned by
// GetBoundChannelsByModelsMap.
type BoundChannel struct {
// Name is serialized as "name".
Name string `json:"name"`
// Type is serialized as "type"; the meaning of the integer is defined outside this file.
Type int `json:"type"`
}
// Model is the GORM-mapped metadata record for a model and doubles as its JSON
// representation. Fields tagged gorm:"-" are never read from or written to the
// database; they exist only in the JSON output.
type Model struct {
Id int `json:"id"`
// ModelName is required (not null, size 128) and is unique together with
// DeletedAt through the composite index uk_model_name_delete_at.
ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name_delete_at,priority:1"`
Description string `json:"description,omitempty" gorm:"type:text"`
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
// Tags is stored as a single varchar(255); the encoding of multiple tags is
// not visible in this file.
Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
// VendorID is indexed; SearchModels joins it against vendors.id.
VendorID int `json:"vendor_id,omitempty" gorm:"index"`
// Endpoints is stored as text; its format is not visible in this file.
Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
// Status and SyncOfficial both declare a column default of 1. Insert and
// Update contain explicit handling so that a zero value can still be stored.
Status int `json:"status" gorm:"default:1"`
SyncOfficial int `json:"sync_official" gorm:"default:1"`
// CreatedTime and UpdatedTime are set from common.GetTimestamp by Insert and
// Update. NOTE(review): the tag is `bigint` without a `type:` key — confirm
// this is intentional.
CreatedTime int64 `json:"created_time" gorm:"bigint"`
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
// DeletedAt is hidden from JSON and is the second column of the
// uk_model_name_delete_at unique index.
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name_delete_at,priority:2"`
// The following three fields are not persisted (gorm:"-").
BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
QuotaTypes []int `json:"quota_types,omitempty" gorm:"-"`
// NameRule has a column default of 0; presumably it holds one of the
// NameRule* constants — confirm against the callers.
NameRule int `json:"name_rule" gorm:"default:0"`
// MatchedModels and MatchedCount are not persisted (gorm:"-").
MatchedModels []string `json:"matched_models,omitempty" gorm:"-"`
MatchedCount int `json:"matched_count,omitempty" gorm:"-"`
}
// Insert creates the record for mi and stamps CreatedTime and UpdatedTime with
// the current timestamp.
//
// Status and SyncOfficial carry a `default:1` gorm tag, so Create replaces a
// zero value with the default (and writes that default back into mi). To let
// callers store an explicit 0, the caller-supplied values are captured before
// Create and written afterwards with a map-based Updates, which does not skip
// zero values.
//
// Both statements run inside one transaction, so a failure of the second
// statement rolls back the insert instead of leaving a row with unintended
// defaults. Before returning, the receiver's Status and SyncOfficial are reset
// to the caller-supplied values so the struct matches what was requested.
//
// It returns the first error reported by the create, the update, or the
// transaction itself.
func (mi *Model) Insert() error {
	now := common.GetTimestamp()
	mi.CreatedTime = now
	mi.UpdatedTime = now

	// Capture the requested values before Create can overwrite them.
	status, syncOfficial := mi.Status, mi.SyncOfficial

	err := DB.Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(mi).Error; err != nil {
			return err
		}
		// A map is used so that zero values are included in the UPDATE.
		return tx.Model(&Model{}).Where("id = ?", mi.Id).Updates(map[string]interface{}{
			"status":        status,
			"sync_official": syncOfficial,
		}).Error
	})

	// Undo the default values GORM may have written into the receiver.
	mi.Status, mi.SyncOfficial = status, syncOfficial
	return err
}
// IsModelNameDuplicated reports whether a record other than the one with the
// given id already uses name as its model_name. An empty name is never
// considered a duplicate and returns without querying the database. When the
// count query fails, the error is returned alongside the (false) result.
func IsModelNameDuplicated(id int, name string) (bool, error) {
	if len(name) == 0 {
		return false, nil
	}

	var matches int64
	query := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id)
	err := query.Count(&matches).Error
	return matches > 0, err
}
// Update refreshes UpdatedTime and writes the editable columns of mi to the
// row whose id equals mi.Id. The column list is passed to Select explicitly so
// that zero values (for example Status == 0) are written rather than skipped.
func (mi *Model) Update() error {
	mi.UpdatedTime = common.GetTimestamp()

	target := DB.Model(&Model{}).Where("id = ?", mi.Id)
	target = target.Select(
		"model_name", "description", "icon", "tags", "vendor_id",
		"endpoints", "status", "sync_official", "name_rule", "updated_time",
	)
	result := target.Updates(mi)
	return result.Error
}
// Delete passes mi to DB.Delete and returns the resulting error, if any.
// Model declares a gorm.DeletedAt field, which GORM documents as enabling soft
// deletion for the type.
func (mi *Model) Delete() error {
	result := DB.Delete(mi)
	return result.Error
}
// GetVendorModelCounts groups the model rows by vendor_id and returns a map
// from vendor id to the number of rows in that group. On query failure it
// returns a nil map and the error.
func GetVendorModelCounts() (map[int64]int64, error) {
	type vendorCount struct {
		VendorID int64
		Count    int64
	}

	var rows []vendorCount
	query := DB.Model(&Model{}).
		Select("vendor_id as vendor_id, count(*) as count").
		Group("vendor_id")
	if err := query.Scan(&rows).Error; err != nil {
		return nil, err
	}

	counts := make(map[int64]int64, len(rows))
	for i := range rows {
		counts[rows[i].VendorID] = rows[i].Count
	}
	return counts, nil
}
// GetAllModels returns one page of models ordered by id descending, skipping
// offset rows and returning at most limit rows. The slice filled by Find is
// returned together with the query error, if any.
func GetAllModels(offset int, limit int) ([]*Model, error) {
	var page []*Model
	result := DB.Order("id DESC").Offset(offset).Limit(limit).Find(&page)
	return page, result.Error
}
// GetBoundChannelsByModelsMap looks up, for each of the given model names, the
// distinct (channel name, channel type) pairs of channels joined to an enabled
// abilities row for that model. The result maps model name to its channels;
// model names with no match are absent from the map.
//
// An empty modelNames returns an empty, non-nil map without querying. On query
// failure it returns a nil map and the error.
func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel, error) {
	bound := make(map[string][]BoundChannel)
	if len(modelNames) == 0 {
		return bound, nil
	}

	type binding struct {
		Model string
		Name  string
		Type  int
	}

	var bindings []binding
	query := DB.Table("channels").
		Select("abilities.model as model, channels.name as name, channels.type as type").
		Joins("JOIN abilities ON abilities.channel_id = channels.id").
		Where("abilities.model IN ? AND abilities.enabled = ?", modelNames, true).
		Distinct()
	if err := query.Scan(&bindings).Error; err != nil {
		return nil, err
	}

	for i := range bindings {
		b := &bindings[i]
		bound[b.Model] = append(bound[b.Model], BoundChannel{Name: b.Name, Type: b.Type})
	}
	return bound, nil
}
// SearchModels returns one page of models matching the optional filters,
// together with the total number of matching rows.
//
// keyword, when non-empty, is matched as a substring (LIKE %keyword%) against
// the model's model_name, description, or tags.
//
// vendor, when non-empty, is interpreted in one of two ways: if it parses as
// an integer it is compared to models.vendor_id; otherwise the query joins the
// vendors table and matches vendor as a substring of vendors.name.
//
// offset and limit page the result, which is ordered by models.id descending.
// On any query error it returns (nil, 0, err).
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
	var models []*Model
	db := DB.Model(&Model{})
	if keyword != "" {
		like := "%" + keyword + "%"
		// Columns are qualified with the table name, like every other column
		// in this query, so they stay unambiguous when the vendors join below
		// is active.
		db = db.Where("models.model_name LIKE ? OR models.description LIKE ? OR models.tags LIKE ?", like, like, like)
	}
	if vendor != "" {
		if vid, err := strconv.Atoi(vendor); err == nil {
			db = db.Where("models.vendor_id = ?", vid)
		} else {
			db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%")
		}
	}

	// Count first so the total reflects the filters but not the paging.
	var total int64
	if err := db.Count(&total).Error; err != nil {
		return nil, 0, err
	}
	if err := db.Order("models.id DESC").Offset(offset).Limit(limit).Find(&models).Error; err != nil {
		return nil, 0, err
	}
	return models, total, nil
}