Files
new-api/model/main.go

705 lines
20 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package model
import (
"fmt"
"log"
"os"
"strings"
"sync"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/glebarez/sqlite"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
// Quoted column identifiers and boolean literals for the main database.
// All four are assigned by initCol according to common.UsingPostgreSQL.
var commonGroupCol string
var commonKeyCol string
var commonTrueVal string
var commonFalseVal string
// Quoted column identifiers for the log database, also assigned by initCol.
var logKeyCol string
var logGroupCol string
// initCol assigns the package-level quoted column names and boolean literals
// for the main database, and the quoted column names for the log database.
//
// The main database uses double-quoted identifiers and true/false on
// PostgreSQL, and backtick-quoted identifiers and 1/0 otherwise.
//
// The log database follows the main database when LOG_SQL_DSN is empty (the
// two are then the same connection). When LOG_SQL_DSN is set, quoting is
// chosen from common.LogSqlType alone: a non-PostgreSQL log database must not
// inherit double-quoted identifiers from a PostgreSQL main database, because
// MySQL treats "group" as a string literal unless ANSI_QUOTES is enabled.
func initCol() {
	const (
		pgGroupCol, pgKeyCol             = `"group"`, `"key"`
		backtickGroupCol, backtickKeyCol = "`group`", "`key`"
	)

	// Main database identifiers and boolean literals.
	if common.UsingPostgreSQL {
		commonGroupCol, commonKeyCol = pgGroupCol, pgKeyCol
		commonTrueVal, commonFalseVal = "true", "false"
	} else {
		commonGroupCol, commonKeyCol = backtickGroupCol, backtickKeyCol
		commonTrueVal, commonFalseVal = "1", "0"
	}

	// When LOG_SQL_DSN is empty the log database is the main database.
	if os.Getenv("LOG_SQL_DSN") == "" {
		logGroupCol, logKeyCol = commonGroupCol, commonKeyCol
		return
	}

	// Separate log database: quote according to its own backend.
	if common.LogSqlType == common.DatabaseTypePostgreSQL {
		logGroupCol, logKeyCol = pgGroupCol, pgKeyCol
	} else {
		logGroupCol, logKeyCol = backtickGroupCol, backtickKeyCol
	}
}
// DB is the main database handle; it is assigned by InitDB.
var DB *gorm.DB
// LOG_DB is the log database handle; InitLogDB sets it to DB when LOG_SQL_DSN
// is empty, otherwise to a separate connection.
var LOG_DB *gorm.DB
// createRootAccountIfNeed inserts a default root user when the first-user
// lookup fails (i.e. DB.First returns any error, normally "record not found").
//
// It returns an error if the default password cannot be hashed or if the
// insert itself fails; previously a failed insert was silently ignored.
//
// NOTE(review): the account is created with the well-known password "123456";
// operators are expected to change it immediately.
func createRootAccountIfNeed() error {
	var user User
	if err := DB.First(&user).Error; err == nil {
		// At least one user already exists; nothing to do.
		return nil
	}

	common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
	hashedPassword, err := common.Password2Hash("123456")
	if err != nil {
		return err
	}
	rootUser := User{
		Username:    "root",
		Password:    hashedPassword,
		Role:        common.RoleRootUser,
		Status:      common.UserStatusEnabled,
		DisplayName: "Root User",
		AccessToken: nil,
		Quota:       100000000,
	}
	// Surface insert failures instead of reporting success with no root user.
	return DB.Create(&rootUser).Error
}
// CheckSetup determines whether the system has been initialized and stores
// the answer in constant.Setup.
//
//   - A setup record exists: the system is initialized.
//   - No setup record but a root user exists: a setup record is created
//     (a creation failure is only logged) and the system counts as initialized.
//   - Neither exists: the system is not initialized.
func CheckSetup() {
	if existing := GetSetup(); existing != nil {
		// Setup record exists, system is initialized.
		initializedAt := time.Unix(existing.InitializedAt, 0)
		common.SysLog("system is already initialized at: " + initializedAt.String())
		constant.Setup = true
		return
	}

	if !RootUserExists() {
		common.SysLog("system is not initialized and no root user exists")
		constant.Setup = false
		return
	}

	// A root user exists without a setup record: backfill the record.
	common.SysLog("system is not initialized, but root user exists")
	record := Setup{
		Version:       common.Version,
		InitializedAt: time.Now().Unix(),
	}
	if createErr := DB.Create(&record).Error; createErr != nil {
		common.SysLog("failed to create setup record: " + createErr.Error())
	}
	constant.Setup = true
}
// chooseDB opens a GORM connection for the DSN stored in the environment
// variable envName and records which backend was selected.
//
// Backend selection:
//   - DSN starting with "postgres://" or "postgresql://": PostgreSQL.
//   - DSN empty or starting with "local": SQLite at common.SQLitePath.
//   - anything else: MySQL ("parseTime=true" is appended when the DSN does
//     not mention parseTime).
//
// When isLog is false the main-database flags (common.UsingPostgreSQL,
// common.UsingSQLite, common.UsingMySQL) are set; when isLog is true only
// common.LogSqlType is set. initCol runs on return so the column quoting
// reflects the flags chosen here.
func chooseDB(envName string, isLog bool) (*gorm.DB, error) {
	defer initCol()

	gormConfig := &gorm.Config{
		PrepareStmt: true, // precompile SQL
	}
	dsn := os.Getenv(envName)

	switch {
	case strings.HasPrefix(dsn, "postgres://"), strings.HasPrefix(dsn, "postgresql://"):
		common.SysLog("using PostgreSQL as database")
		if isLog {
			common.LogSqlType = common.DatabaseTypePostgreSQL
		} else {
			common.UsingPostgreSQL = true
		}
		return gorm.Open(postgres.New(postgres.Config{
			DSN:                  dsn,
			PreferSimpleProtocol: true, // disables implicit prepared statement usage
		}), gormConfig)

	case dsn == "", strings.HasPrefix(dsn, "local"):
		if dsn == "" {
			common.SysLog(envName + " not set, using SQLite as database")
		} else {
			// The DSN is set (to "local..."), so do not claim it is missing.
			common.SysLog("using SQLite as database")
		}
		// Record the backend on the side that was asked for, so opening a log
		// database never flips the main-database flag.
		if isLog {
			common.LogSqlType = common.DatabaseTypeSQLite
		} else {
			common.UsingSQLite = true
		}
		return gorm.Open(sqlite.Open(common.SQLitePath), gormConfig)
	}

	// Everything else is treated as a MySQL DSN.
	common.SysLog("using MySQL as database")
	// Ensure time columns are parsed into time.Time by the driver.
	if !strings.Contains(dsn, "parseTime") {
		if strings.Contains(dsn, "?") {
			dsn += "&parseTime=true"
		} else {
			dsn += "?parseTime=true"
		}
	}
	if isLog {
		common.LogSqlType = common.DatabaseTypeMySQL
	} else {
		common.UsingMySQL = true
	}
	return gorm.Open(mysql.Open(dsn), gormConfig)
}
// InitDB opens the main database from SQL_DSN, stores it in DB, configures
// the connection pool from SQL_MAX_IDLE_CONNS / SQL_MAX_OPEN_CONNS /
// SQL_MAX_LIFETIME (seconds), and runs migrateDB on the master node.
//
// A failure to open the database is passed to common.FatalLog. On MySQL a
// failed charset check panics. Non-master nodes return after pool setup
// without migrating.
func InitDB() (err error) {
	db, err := chooseDB("SQL_DSN", false)
	if err != nil {
		common.FatalLog(err)
		return err
	}

	if common.DebugEnabled {
		db = db.Debug()
	}
	DB = db

	// MySQL charset/collation startup check: ensure Chinese-capable charset.
	if common.UsingMySQL {
		if charsetErr := checkMySQLChineseSupport(DB); charsetErr != nil {
			panic(charsetErr)
		}
	}

	pool, poolErr := DB.DB()
	if poolErr != nil {
		return poolErr
	}
	pool.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
	pool.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
	lifetimeSeconds := common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)
	pool.SetConnMaxLifetime(time.Second * time.Duration(lifetimeSeconds))

	// Only the master node performs schema migration.
	if !common.IsMasterNode {
		return nil
	}
	common.SysLog("database migration started")
	return migrateDB()
}
// InitLogDB prepares LOG_DB. When LOG_SQL_DSN is empty it reuses DB and
// returns nil. Otherwise it opens a separate connection, configures its pool
// from SQL_MAX_IDLE_CONNS / SQL_MAX_OPEN_CONNS / SQL_MAX_LIFETIME (seconds),
// and runs migrateLOGDB on the master node.
//
// A failure to open the database is passed to common.FatalLog. On a MySQL log
// database a failed charset check panics.
func InitLogDB() (err error) {
	if os.Getenv("LOG_SQL_DSN") == "" {
		LOG_DB = DB
		return nil
	}

	db, err := chooseDB("LOG_SQL_DSN", true)
	if err != nil {
		common.FatalLog(err)
		return err
	}

	if common.DebugEnabled {
		db = db.Debug()
	}
	LOG_DB = db

	// If the log DB is MySQL, also ensure a Chinese-capable charset.
	if common.LogSqlType == common.DatabaseTypeMySQL {
		if charsetErr := checkMySQLChineseSupport(LOG_DB); charsetErr != nil {
			panic(charsetErr)
		}
	}

	pool, poolErr := LOG_DB.DB()
	if poolErr != nil {
		return poolErr
	}
	pool.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100))
	pool.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000))
	lifetimeSeconds := common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60)
	pool.SetConnMaxLifetime(time.Second * time.Duration(lifetimeSeconds))

	// Only the master node performs schema migration.
	if !common.IsMasterNode {
		return nil
	}
	common.SysLog("database migration started")
	return migrateLOGDB()
}
// migrateDB brings the main database schema up to date.
//
// It first runs the two column-type pre-migrations (price_amount to decimal,
// model_limits to text), then auto-migrates every model in one call, and
// finally handles subscription_plans: through the hand-written SQLite routine
// on SQLite, through AutoMigrate elsewhere.
func migrateDB() error {
	// Best-effort: this pre-migration only logs its own failures.
	migrateSubscriptionPlanPriceAmount()
	if err := migrateTokenModelLimitsToText(); err != nil {
		return err
	}

	models := []interface{}{
		&Channel{},
		&Token{},
		&User{},
		&PasskeyCredential{},
		&Option{},
		&Redemption{},
		&Ability{},
		&Log{},
		&Midjourney{},
		&TopUp{},
		&QuotaData{},
		&Task{},
		&Model{},
		&Vendor{},
		&PrefillGroup{},
		&Setup{},
		&TwoFA{},
		&TwoFABackupCode{},
		&Checkin{},
		&SubscriptionOrder{},
		&UserSubscription{},
		&SubscriptionPreConsumeRecord{},
		&CustomOAuthProvider{},
		&UserOAuthBinding{},
	}
	if err := DB.AutoMigrate(models...); err != nil {
		return err
	}

	if common.UsingSQLite {
		return ensureSubscriptionPlanTableSQLite()
	}
	return DB.AutoMigrate(&SubscriptionPlan{})
}
// migrateDBFast auto-migrates every model concurrently, one goroutine per
// model, and then handles subscription_plans (hand-written routine on SQLite,
// AutoMigrate elsewhere).
//
// If one or more migrations fail, one of the failures is returned, wrapped
// with the model name; the underlying error stays reachable through
// errors.Is / errors.As.
func migrateDBFast() error {
	migrations := []struct {
		model interface{}
		name  string
	}{
		{&Channel{}, "Channel"},
		{&Token{}, "Token"},
		{&User{}, "User"},
		{&PasskeyCredential{}, "PasskeyCredential"},
		{&Option{}, "Option"},
		{&Redemption{}, "Redemption"},
		{&Ability{}, "Ability"},
		{&Log{}, "Log"},
		{&Midjourney{}, "Midjourney"},
		{&TopUp{}, "TopUp"},
		{&QuotaData{}, "QuotaData"},
		{&Task{}, "Task"},
		{&Model{}, "Model"},
		{&Vendor{}, "Vendor"},
		{&PrefillGroup{}, "PrefillGroup"},
		{&Setup{}, "Setup"},
		{&TwoFA{}, "TwoFA"},
		{&TwoFABackupCode{}, "TwoFABackupCode"},
		{&Checkin{}, "Checkin"},
		{&SubscriptionOrder{}, "SubscriptionOrder"},
		{&UserSubscription{}, "UserSubscription"},
		{&SubscriptionPreConsumeRecord{}, "SubscriptionPreConsumeRecord"},
		{&CustomOAuthProvider{}, "CustomOAuthProvider"},
		{&UserOAuthBinding{}, "UserOAuthBinding"},
	}

	// Size the buffer from the migration count so no goroutine ever blocks
	// on sending its error.
	errChan := make(chan error, len(migrations))
	var wg sync.WaitGroup
	for _, m := range migrations {
		wg.Add(1)
		go func(model interface{}, name string) {
			defer wg.Done()
			if err := DB.AutoMigrate(model); err != nil {
				errChan <- fmt.Errorf("failed to migrate %s: %w", name, err)
			}
		}(m.model, m.name)
	}

	// Wait for all migrations to complete before inspecting the results.
	wg.Wait()
	close(errChan)

	// Only non-nil errors are ever sent, so any received value is a failure.
	if migrateErr, failed := <-errChan; failed {
		return migrateErr
	}

	if common.UsingSQLite {
		if err := ensureSubscriptionPlanTableSQLite(); err != nil {
			return err
		}
	} else {
		if err := DB.AutoMigrate(&SubscriptionPlan{}); err != nil {
			return err
		}
	}
	common.SysLog("database migrated")
	return nil
}
// migrateLOGDB auto-migrates the Log model on the log database and returns
// the migration error, if any.
func migrateLOGDB() error {
	return LOG_DB.AutoMigrate(&Log{})
}
// sqliteColumnDef pairs a column name with the DDL fragment used to add that
// column through ALTER TABLE ... ADD COLUMN in ensureSubscriptionPlanTableSQLite.
type sqliteColumnDef struct {
// Name is the bare column name, compared against PRAGMA table_info output.
Name string
// DDL is the full column definition (quoted name, type, constraints).
DDL string
}
// ensureSubscriptionPlanTableSQLite creates or patches the subscription_plans
// table with hand-written SQL instead of AutoMigrate. It is a no-op unless
// common.UsingSQLite is set.
//
// If the table does not exist it is created with the full column set. If it
// exists, the current columns are read via PRAGMA table_info and every
// required column that is missing is added with ALTER TABLE ... ADD COLUMN.
// Existing columns are never altered. The first SQL error is returned.
//
// NOTE(review): the column list is duplicated between the CREATE TABLE
// statement and the `required` slice; keep both in sync when adding a column.
// NOTE(review): `title` and `price_amount` are NOT NULL without a DEFAULT;
// SQLite's ADD COLUMN is documented to reject that combination, so the patch
// path may fail for tables missing those columns — confirm.
func ensureSubscriptionPlanTableSQLite() error {
if !common.UsingSQLite {
return nil
}
tableName := "subscription_plans"
// Fresh database: create the table with the full schema and stop.
if !DB.Migrator().HasTable(tableName) {
createSQL := `CREATE TABLE ` + "`" + tableName + "`" + ` (
` + "`id`" + ` integer,
` + "`title`" + ` varchar(128) NOT NULL,
` + "`subtitle`" + ` varchar(255) DEFAULT '',
` + "`price_amount`" + ` decimal(10,6) NOT NULL,
` + "`currency`" + ` varchar(8) NOT NULL DEFAULT 'USD',
` + "`duration_unit`" + ` varchar(16) NOT NULL DEFAULT 'month',
` + "`duration_value`" + ` integer NOT NULL DEFAULT 1,
` + "`custom_seconds`" + ` bigint NOT NULL DEFAULT 0,
` + "`enabled`" + ` numeric DEFAULT 1,
` + "`sort_order`" + ` integer DEFAULT 0,
` + "`stripe_price_id`" + ` varchar(128) DEFAULT '',
` + "`creem_product_id`" + ` varchar(128) DEFAULT '',
` + "`max_purchase_per_user`" + ` integer DEFAULT 0,
` + "`upgrade_group`" + ` varchar(64) DEFAULT '',
` + "`total_amount`" + ` bigint NOT NULL DEFAULT 0,
` + "`quota_reset_period`" + ` varchar(16) DEFAULT 'never',
` + "`quota_reset_custom_seconds`" + ` bigint DEFAULT 0,
` + "`created_at`" + ` bigint,
` + "`updated_at`" + ` bigint,
PRIMARY KEY (` + "`id`" + `)
)`
return DB.Exec(createSQL).Error
}
// Existing table: collect the names of the columns it already has.
var cols []struct {
Name string `gorm:"column:name"`
}
if err := DB.Raw("PRAGMA table_info(`" + tableName + "`)").Scan(&cols).Error; err != nil {
return err
}
existing := make(map[string]struct{}, len(cols))
for _, c := range cols {
existing[c.Name] = struct{}{}
}
// Every non-key column the table must have (`id` is not listed here).
required := []sqliteColumnDef{
{Name: "title", DDL: "`title` varchar(128) NOT NULL"},
{Name: "subtitle", DDL: "`subtitle` varchar(255) DEFAULT ''"},
{Name: "price_amount", DDL: "`price_amount` decimal(10,6) NOT NULL"},
{Name: "currency", DDL: "`currency` varchar(8) NOT NULL DEFAULT 'USD'"},
{Name: "duration_unit", DDL: "`duration_unit` varchar(16) NOT NULL DEFAULT 'month'"},
{Name: "duration_value", DDL: "`duration_value` integer NOT NULL DEFAULT 1"},
{Name: "custom_seconds", DDL: "`custom_seconds` bigint NOT NULL DEFAULT 0"},
{Name: "enabled", DDL: "`enabled` numeric DEFAULT 1"},
{Name: "sort_order", DDL: "`sort_order` integer DEFAULT 0"},
{Name: "stripe_price_id", DDL: "`stripe_price_id` varchar(128) DEFAULT ''"},
{Name: "creem_product_id", DDL: "`creem_product_id` varchar(128) DEFAULT ''"},
{Name: "max_purchase_per_user", DDL: "`max_purchase_per_user` integer DEFAULT 0"},
{Name: "upgrade_group", DDL: "`upgrade_group` varchar(64) DEFAULT ''"},
{Name: "total_amount", DDL: "`total_amount` bigint NOT NULL DEFAULT 0"},
{Name: "quota_reset_period", DDL: "`quota_reset_period` varchar(16) DEFAULT 'never'"},
{Name: "quota_reset_custom_seconds", DDL: "`quota_reset_custom_seconds` bigint DEFAULT 0"},
{Name: "created_at", DDL: "`created_at` bigint"},
{Name: "updated_at", DDL: "`updated_at` bigint"},
}
// Add only the columns that are missing; stop at the first failure.
for _, col := range required {
if _, ok := existing[col.Name]; ok {
continue
}
if err := DB.Exec("ALTER TABLE `" + tableName + "` ADD COLUMN " + col.DDL).Error; err != nil {
return err
}
}
return nil
}
// migrateTokenModelLimitsToText converts tokens.model_limits to the text type
// on PostgreSQL and MySQL.
//
// It does nothing on SQLite, when the table or column is missing, or when the
// column metadata already reports "text". If the metadata query itself fails,
// a warning is logged and the ALTER is attempted anyway. A failed ALTER is
// returned as a wrapped error.
func migrateTokenModelLimitsToText() error {
	const (
		table  = "tokens"
		column = "model_limits"
	)

	// SQLite uses type affinity, so TEXT and VARCHAR behave the same there.
	if common.UsingSQLite {
		return nil
	}
	if !DB.Migrator().HasTable(table) || !DB.Migrator().HasColumn(&Token{}, column) {
		return nil
	}

	var stmt string
	switch {
	case common.UsingPostgreSQL:
		var currentType string
		lookupErr := DB.Raw(`SELECT data_type FROM information_schema.columns
WHERE table_schema = current_schema() AND table_name = ? AND column_name = ?`,
			table, column).Scan(&currentType).Error
		if lookupErr != nil {
			common.SysLog(fmt.Sprintf("Warning: failed to query metadata for %s.%s: %v", table, column, lookupErr))
		} else if currentType == "text" {
			return nil
		}
		stmt = fmt.Sprintf(`ALTER TABLE %s ALTER COLUMN %s TYPE text`, table, column)

	case common.UsingMySQL:
		var currentType string
		lookupErr := DB.Raw(`SELECT COLUMN_TYPE FROM information_schema.columns
WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?`,
			table, column).Scan(&currentType).Error
		if lookupErr != nil {
			common.SysLog(fmt.Sprintf("Warning: failed to query metadata for %s.%s: %v", table, column, lookupErr))
		} else if strings.ToLower(currentType) == "text" {
			return nil
		}
		stmt = fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s text", table, column)

	default:
		// Unknown backend: leave the schema alone.
		return nil
	}

	if execErr := DB.Exec(stmt).Error; execErr != nil {
		return fmt.Errorf("failed to migrate %s.%s to text: %w", table, column, execErr)
	}
	common.SysLog(fmt.Sprintf("Successfully migrated %s.%s to text", table, column))
	return nil
}
// migrateSubscriptionPlanPriceAmount converts subscription_plans.price_amount
// to decimal(10,6) on PostgreSQL and MySQL.
//
// It is best-effort: every failure is logged as a warning and nothing is
// returned. It does nothing on SQLite, when the table or column is missing,
// or when the metadata already reports a numeric/decimal type. If the
// metadata query itself fails, the ALTER is attempted anyway.
func migrateSubscriptionPlanPriceAmount() {
	const (
		table  = "subscription_plans"
		column = "price_amount"
	)

	// SQLite has no ALTER COLUMN and relies on type affinity; returning early
	// also avoids GORM parsing the existing table DDL.
	if common.UsingSQLite {
		return
	}
	if !DB.Migrator().HasTable(table) || !DB.Migrator().HasColumn(&SubscriptionPlan{}, column) {
		return
	}

	var stmt string
	switch {
	case common.UsingPostgreSQL:
		var currentType string
		lookupErr := DB.Raw(`SELECT data_type FROM information_schema.columns
WHERE table_schema = current_schema() AND table_name = ? AND column_name = ?`,
			table, column).Scan(&currentType).Error
		if lookupErr != nil {
			common.SysLog(fmt.Sprintf("Warning: failed to query metadata for %s.%s: %v", table, column, lookupErr))
		} else if currentType == "numeric" {
			return // already decimal/numeric
		}
		stmt = fmt.Sprintf(`ALTER TABLE %s ALTER COLUMN %s TYPE decimal(10,6) USING %s::decimal(10,6)`,
			table, column, column)

	case common.UsingMySQL:
		var currentType string
		lookupErr := DB.Raw(`SELECT COLUMN_TYPE FROM information_schema.columns
WHERE table_schema = DATABASE() AND table_name = ? AND column_name = ?`,
			table, column).Scan(&currentType).Error
		if lookupErr != nil {
			common.SysLog(fmt.Sprintf("Warning: failed to query metadata for %s.%s: %v", table, column, lookupErr))
		} else if strings.HasPrefix(strings.ToLower(currentType), "decimal") {
			return // already decimal
		}
		stmt = fmt.Sprintf("ALTER TABLE %s MODIFY COLUMN %s decimal(10,6) NOT NULL DEFAULT 0",
			table, column)

	default:
		// Unknown backend: leave the schema alone.
		return
	}

	if execErr := DB.Exec(stmt).Error; execErr != nil {
		common.SysLog(fmt.Sprintf("Warning: failed to migrate %s.%s to decimal: %v", table, column, execErr))
		return
	}
	common.SysLog(fmt.Sprintf("Successfully migrated %s.%s to decimal(10,6)", table, column))
}
// closeDB closes the *sql.DB behind the given GORM handle and returns the
// first error encountered, either from obtaining the pool or from closing it.
func closeDB(db *gorm.DB) error {
	pool, err := db.DB()
	if err != nil {
		return err
	}
	return pool.Close()
}
// CloseDB closes the log database (when it is a separate connection) and the
// main database.
//
// Both handles are always attempted: a failure to close the log database no
// longer leaves the main pool open. Handles that were never initialized (nil)
// are skipped instead of causing a nil-pointer panic. If both closes fail,
// the log database error is returned, matching the previous precedence.
func CloseDB() error {
	var logErr, mainErr error
	if LOG_DB != nil && LOG_DB != DB {
		logErr = closeDB(LOG_DB)
	}
	if DB != nil {
		mainErr = closeDB(DB)
	}
	if logErr != nil {
		return logErr
	}
	return mainErr
}
// checkMySQLChineseSupport verifies that the current MySQL schema's default
// charset/collation, and the collation of every base table in it, can store
// Chinese text. Accepted charsets are utf8mb4, utf8 (reported as utf8mb3 by
// newer MySQL servers), gbk, big5 and gb18030.
//
// It returns a descriptive error when a check fails or when the metadata
// cannot be read; it does not panic itself (callers decide what to do).
func checkMySQLChineseSupport(db *gorm.DB) error {
	// Read the current schema defaults.
	var schemaCharset, schemaCollation string
	err := db.Raw("SELECT DEFAULT_CHARACTER_SET_NAME, DEFAULT_COLLATION_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = DATABASE()").Row().Scan(&schemaCharset, &schemaCollation)
	if err != nil {
		return fmt.Errorf("读取当前库默认字符集/排序规则失败 / Failed to read schema default charset/collation: %v", err)
	}

	// Charsets that can store Chinese text, mapped to their collation prefix.
	allowedCharsets := map[string]string{
		"utf8mb4": "utf8mb4_",
		"utf8":    "utf8_",
		// Newer MySQL servers report the legacy utf8 charset under this name.
		"utf8mb3": "utf8mb3_",
		"gbk":     "gbk_",
		"big5":    "big5_",
		"gb18030": "gb18030_",
	}
	// collationAllowed reports whether a collation name belongs to one of the
	// accepted charsets, judged by its prefix alone.
	collationAllowed := func(collation string) bool {
		lower := strings.ToLower(collation)
		for _, prefix := range allowedCharsets {
			if strings.HasPrefix(lower, prefix) {
				return true
			}
		}
		return false
	}
	isChineseCapable := func(cs, cl string) bool {
		if prefix, ok := allowedCharsets[strings.ToLower(cs)]; ok {
			// Known charset: an empty collation is fine, otherwise the
			// collation must belong to that same charset.
			return cl == "" || strings.HasPrefix(strings.ToLower(cl), prefix)
		}
		// Unknown or missing charset: fall back to the collation prefix.
		return collationAllowed(cl)
	}

	// 1) The schema defaults must be Chinese-capable.
	if !isChineseCapable(schemaCharset, schemaCollation) {
		return fmt.Errorf("当前库默认字符集/排序规则不支持中文schema(%s/%s)。请将库设置为 utf8mb4/utf8/gbk/big5/gb18030 / Schema default charset/collation is not Chinese-capable: schema(%s/%s). Please set to utf8mb4/utf8/gbk/big5/gb18030",
			schemaCharset, schemaCollation, schemaCharset, schemaCollation)
	}

	// 2) Every base table's collation (which implies its charset) must be
	// Chinese-capable. The explicit column tags are required: without them the
	// result columns TABLE_NAME / TABLE_COLLATION are not expected to map onto
	// fields named Name / Collation, leaving every row empty and the check
	// ineffective.
	type tableInfo struct {
		Name      string  `gorm:"column:TABLE_NAME"`
		Collation *string `gorm:"column:TABLE_COLLATION"`
	}
	var tables []tableInfo
	if err := db.Raw("SELECT TABLE_NAME, TABLE_COLLATION FROM information_schema.TABLES WHERE TABLE_SCHEMA = DATABASE() AND TABLE_TYPE = 'BASE TABLE'").Scan(&tables).Error; err != nil {
		return fmt.Errorf("读取表排序规则失败 / Failed to read table collations: %v", err)
	}

	var badTables []string
	for _, t := range tables {
		// NULL or empty means the table inherits the schema default, which
		// was validated above.
		if t.Collation == nil || *t.Collation == "" {
			continue
		}
		if !collationAllowed(*t.Collation) {
			badTables = append(badTables, fmt.Sprintf("%s(%s)", t.Name, *t.Collation))
		}
	}
	if len(badTables) > 0 {
		// Cap the number of reported tables to keep the message short.
		maxShow := 20
		shown := badTables
		if len(shown) > maxShow {
			shown = shown[:maxShow]
		}
		return fmt.Errorf(
			"存在不支持中文的表,请修复其排序规则/字符集。示例(最多展示 %d 项):%v / Found tables not Chinese-capable. Please fix their collation/charset. Examples (showing up to %d): %v",
			maxShow, shown, maxShow, shown,
		)
	}
	return nil
}
// State used by PingDB to rate-limit database pings.
var (
// lastPingTime is updated by PingDB after each successful ping.
lastPingTime time.Time
// pingMutex is held by PingDB for the duration of each call.
pingMutex sync.Mutex
)
// PingDB pings the main database, at most once every 10 seconds.
//
// Calls made within 10 seconds of the last successful ping return nil without
// touching the database. Errors from obtaining the pool or from the ping are
// logged and returned; a failed ping does not update the timestamp, so the
// next call pings again. The mutex is held for the whole call.
func PingDB() error {
	pingMutex.Lock()
	defer pingMutex.Unlock()

	if time.Since(lastPingTime) < 10*time.Second {
		return nil
	}

	pool, err := DB.DB()
	if err != nil {
		log.Printf("Error getting sql.DB from GORM: %v", err)
		return err
	}
	if pingErr := pool.Ping(); pingErr != nil {
		log.Printf("Error pinging DB: %v", pingErr)
		return pingErr
	}

	lastPingTime = time.Now()
	common.SysLog("Database pinged successfully")
	return nil
}