mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 09:55:01 +00:00
- 在创建记录前保存原始状态和同步官方字段值 - 使用独立的更新操作确保零值能够正确保存到数据库 - 修改更新方法使用 Select 强制更新所有字段包括零值 - 避免 GORM 默认行为对零值字段应用默认值导致数据丢失
161 lines
4.8 KiB
Go
161 lines
4.8 KiB
Go
package model
|
||
|
||
import (
|
||
"strconv"
|
||
|
||
"github.com/QuantumNous/new-api/common"
|
||
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// Allowed values for Model.NameRule. The constants are spelled out
// explicitly because they are persisted in the name_rule column, so each
// value must stay stable. 0 (NameRuleExact) matches the column's default:0
// tag on Model.NameRule.
const (
	NameRuleExact    = 0
	NameRulePrefix   = 1
	NameRuleContains = 2
	NameRuleSuffix   = 3
)
|
||
|
||
// BoundChannel identifies a channel that serves a model: the channel's name
// and its type code, as read from the channels table by
// GetBoundChannelsByModelsMap.
type BoundChannel struct {
	Name string `json:"name"` // channels.name
	Type int    `json:"type"` // channels.type
}
|
||
|
||
type Model struct {
|
||
Id int `json:"id"`
|
||
ModelName string `json:"model_name" gorm:"size:128;not null;uniqueIndex:uk_model_name_delete_at,priority:1"`
|
||
Description string `json:"description,omitempty" gorm:"type:text"`
|
||
Icon string `json:"icon,omitempty" gorm:"type:varchar(128)"`
|
||
Tags string `json:"tags,omitempty" gorm:"type:varchar(255)"`
|
||
VendorID int `json:"vendor_id,omitempty" gorm:"index"`
|
||
Endpoints string `json:"endpoints,omitempty" gorm:"type:text"`
|
||
Status int `json:"status" gorm:"default:1"`
|
||
SyncOfficial int `json:"sync_official" gorm:"default:1"`
|
||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||
UpdatedTime int64 `json:"updated_time" gorm:"bigint"`
|
||
DeletedAt gorm.DeletedAt `json:"-" gorm:"index;uniqueIndex:uk_model_name_delete_at,priority:2"`
|
||
|
||
BoundChannels []BoundChannel `json:"bound_channels,omitempty" gorm:"-"`
|
||
EnableGroups []string `json:"enable_groups,omitempty" gorm:"-"`
|
||
QuotaTypes []int `json:"quota_types,omitempty" gorm:"-"`
|
||
NameRule int `json:"name_rule" gorm:"default:0"`
|
||
|
||
MatchedModels []string `json:"matched_models,omitempty" gorm:"-"`
|
||
MatchedCount int `json:"matched_count,omitempty" gorm:"-"`
|
||
}
|
||
|
||
func (mi *Model) Insert() error {
|
||
now := common.GetTimestamp()
|
||
mi.CreatedTime = now
|
||
mi.UpdatedTime = now
|
||
|
||
// 保存原始值(因为 Create 后可能被 GORM 的 default 标签覆盖为 1)
|
||
originalStatus := mi.Status
|
||
originalSyncOfficial := mi.SyncOfficial
|
||
|
||
// 先创建记录(GORM 会对零值字段应用默认值)
|
||
if err := DB.Create(mi).Error; err != nil {
|
||
return err
|
||
}
|
||
|
||
// 使用保存的原始值进行更新,确保零值能正确保存
|
||
return DB.Model(&Model{}).Where("id = ?", mi.Id).Updates(map[string]interface{}{
|
||
"status": originalStatus,
|
||
"sync_official": originalSyncOfficial,
|
||
}).Error
|
||
}
|
||
|
||
func IsModelNameDuplicated(id int, name string) (bool, error) {
|
||
if name == "" {
|
||
return false, nil
|
||
}
|
||
var cnt int64
|
||
err := DB.Model(&Model{}).Where("model_name = ? AND id <> ?", name, id).Count(&cnt).Error
|
||
return cnt > 0, err
|
||
}
|
||
|
||
func (mi *Model) Update() error {
|
||
mi.UpdatedTime = common.GetTimestamp()
|
||
// 使用 Select 强制更新所有字段,包括零值
|
||
return DB.Model(&Model{}).Where("id = ?", mi.Id).
|
||
Select("model_name", "description", "icon", "tags", "vendor_id", "endpoints", "status", "sync_official", "name_rule", "updated_time").
|
||
Updates(mi).Error
|
||
}
|
||
|
||
func (mi *Model) Delete() error {
|
||
return DB.Delete(mi).Error
|
||
}
|
||
|
||
func GetVendorModelCounts() (map[int64]int64, error) {
|
||
var stats []struct {
|
||
VendorID int64
|
||
Count int64
|
||
}
|
||
if err := DB.Model(&Model{}).
|
||
Select("vendor_id as vendor_id, count(*) as count").
|
||
Group("vendor_id").
|
||
Scan(&stats).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
m := make(map[int64]int64, len(stats))
|
||
for _, s := range stats {
|
||
m[s.VendorID] = s.Count
|
||
}
|
||
return m, nil
|
||
}
|
||
|
||
func GetAllModels(offset int, limit int) ([]*Model, error) {
|
||
var models []*Model
|
||
err := DB.Order("id DESC").Offset(offset).Limit(limit).Find(&models).Error
|
||
return models, err
|
||
}
|
||
|
||
func GetBoundChannelsByModelsMap(modelNames []string) (map[string][]BoundChannel, error) {
|
||
result := make(map[string][]BoundChannel)
|
||
if len(modelNames) == 0 {
|
||
return result, nil
|
||
}
|
||
type row struct {
|
||
Model string
|
||
Name string
|
||
Type int
|
||
}
|
||
var rows []row
|
||
err := DB.Table("channels").
|
||
Select("abilities.model as model, channels.name as name, channels.type as type").
|
||
Joins("JOIN abilities ON abilities.channel_id = channels.id").
|
||
Where("abilities.model IN ? AND abilities.enabled = ?", modelNames, true).
|
||
Distinct().
|
||
Scan(&rows).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
for _, r := range rows {
|
||
result[r.Model] = append(result[r.Model], BoundChannel{Name: r.Name, Type: r.Type})
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
func SearchModels(keyword string, vendor string, offset int, limit int) ([]*Model, int64, error) {
|
||
var models []*Model
|
||
db := DB.Model(&Model{})
|
||
if keyword != "" {
|
||
like := "%" + keyword + "%"
|
||
db = db.Where("model_name LIKE ? OR description LIKE ? OR tags LIKE ?", like, like, like)
|
||
}
|
||
if vendor != "" {
|
||
if vid, err := strconv.Atoi(vendor); err == nil {
|
||
db = db.Where("models.vendor_id = ?", vid)
|
||
} else {
|
||
db = db.Joins("JOIN vendors ON vendors.id = models.vendor_id").Where("vendors.name LIKE ?", "%"+vendor+"%")
|
||
}
|
||
}
|
||
var total int64
|
||
if err := db.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
if err := db.Order("models.id DESC").Offset(offset).Limit(limit).Find(&models).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
return models, total, nil
|
||
}
|