Mirror of https://github.com/QuantumNous/new-api.git, synced 2026-03-30 13:50:57 +00:00.
- Updated UpdateWithStatus method to use Model().Select("*").Updates() for conditional updates, preventing GORM's INSERT fallback.
- Introduced comprehensive integration tests for UpdateWithStatus, covering scenarios for winning and losing CAS updates, as well as concurrent updates.
- Added task_cas_test.go to validate the new behavior and ensure data integrity during concurrent state transitions.
218 lines · 5.5 KiB · Go
package model
|
|
|
|
import (
|
|
"encoding/json"
|
|
"os"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
"github.com/glebarez/sqlite"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
|
if err != nil {
|
|
panic("failed to open test db: " + err.Error())
|
|
}
|
|
DB = db
|
|
LOG_DB = db
|
|
|
|
common.UsingSQLite = true
|
|
common.RedisEnabled = false
|
|
common.BatchUpdateEnabled = false
|
|
common.LogConsumeEnabled = true
|
|
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
panic("failed to get sql.DB: " + err.Error())
|
|
}
|
|
sqlDB.SetMaxOpenConns(1)
|
|
|
|
if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil {
|
|
panic("failed to migrate: " + err.Error())
|
|
}
|
|
|
|
os.Exit(m.Run())
|
|
}
|
|
|
|
func truncateTables(t *testing.T) {
|
|
t.Helper()
|
|
t.Cleanup(func() {
|
|
DB.Exec("DELETE FROM tasks")
|
|
DB.Exec("DELETE FROM users")
|
|
DB.Exec("DELETE FROM tokens")
|
|
DB.Exec("DELETE FROM logs")
|
|
DB.Exec("DELETE FROM channels")
|
|
})
|
|
}
|
|
|
|
func insertTask(t *testing.T, task *Task) {
|
|
t.Helper()
|
|
task.CreatedAt = time.Now().Unix()
|
|
task.UpdatedAt = time.Now().Unix()
|
|
require.NoError(t, DB.Create(task).Error)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Snapshot / Equal — pure logic tests (no DB)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestSnapshotEqual_Same(t *testing.T) {
|
|
s := taskSnapshot{
|
|
Status: TaskStatusInProgress,
|
|
Progress: "50%",
|
|
StartTime: 1000,
|
|
FinishTime: 0,
|
|
FailReason: "",
|
|
ResultURL: "",
|
|
Data: json.RawMessage(`{"key":"value"}`),
|
|
}
|
|
assert.True(t, s.Equal(s))
|
|
}
|
|
|
|
func TestSnapshotEqual_DifferentStatus(t *testing.T) {
|
|
a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)}
|
|
b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)}
|
|
assert.False(t, a.Equal(b))
|
|
}
|
|
|
|
func TestSnapshotEqual_DifferentProgress(t *testing.T) {
|
|
a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)}
|
|
b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)}
|
|
assert.False(t, a.Equal(b))
|
|
}
|
|
|
|
func TestSnapshotEqual_DifferentData(t *testing.T) {
|
|
a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)}
|
|
b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)}
|
|
assert.False(t, a.Equal(b))
|
|
}
|
|
|
|
func TestSnapshotEqual_NilVsEmpty(t *testing.T) {
|
|
a := taskSnapshot{Status: TaskStatusInProgress, Data: nil}
|
|
b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}}
|
|
// bytes.Equal(nil, []byte{}) == true
|
|
assert.True(t, a.Equal(b))
|
|
}
|
|
|
|
func TestSnapshot_Roundtrip(t *testing.T) {
|
|
task := &Task{
|
|
Status: TaskStatusInProgress,
|
|
Progress: "42%",
|
|
StartTime: 1234,
|
|
FinishTime: 5678,
|
|
FailReason: "timeout",
|
|
PrivateData: TaskPrivateData{
|
|
ResultURL: "https://example.com/result.mp4",
|
|
},
|
|
Data: json.RawMessage(`{"model":"test-model"}`),
|
|
}
|
|
snap := task.Snapshot()
|
|
assert.Equal(t, task.Status, snap.Status)
|
|
assert.Equal(t, task.Progress, snap.Progress)
|
|
assert.Equal(t, task.StartTime, snap.StartTime)
|
|
assert.Equal(t, task.FinishTime, snap.FinishTime)
|
|
assert.Equal(t, task.FailReason, snap.FailReason)
|
|
assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL)
|
|
assert.JSONEq(t, string(task.Data), string(snap.Data))
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// UpdateWithStatus CAS — DB integration tests
|
|
// ---------------------------------------------------------------------------
|
|
|
|
func TestUpdateWithStatus_Win(t *testing.T) {
|
|
truncateTables(t)
|
|
|
|
task := &Task{
|
|
TaskID: "task_cas_win",
|
|
Status: TaskStatusInProgress,
|
|
Progress: "50%",
|
|
Data: json.RawMessage(`{}`),
|
|
}
|
|
insertTask(t, task)
|
|
|
|
task.Status = TaskStatusSuccess
|
|
task.Progress = "100%"
|
|
won, err := task.UpdateWithStatus(TaskStatusInProgress)
|
|
require.NoError(t, err)
|
|
assert.True(t, won)
|
|
|
|
var reloaded Task
|
|
require.NoError(t, DB.First(&reloaded, task.ID).Error)
|
|
assert.EqualValues(t, TaskStatusSuccess, reloaded.Status)
|
|
assert.Equal(t, "100%", reloaded.Progress)
|
|
}
|
|
|
|
func TestUpdateWithStatus_Lose(t *testing.T) {
|
|
truncateTables(t)
|
|
|
|
task := &Task{
|
|
TaskID: "task_cas_lose",
|
|
Status: TaskStatusFailure,
|
|
Data: json.RawMessage(`{}`),
|
|
}
|
|
insertTask(t, task)
|
|
|
|
task.Status = TaskStatusSuccess
|
|
won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus
|
|
require.NoError(t, err)
|
|
assert.False(t, won)
|
|
|
|
var reloaded Task
|
|
require.NoError(t, DB.First(&reloaded, task.ID).Error)
|
|
assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged
|
|
}
|
|
|
|
func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) {
|
|
truncateTables(t)
|
|
|
|
task := &Task{
|
|
TaskID: "task_cas_race",
|
|
Status: TaskStatusInProgress,
|
|
Quota: 1000,
|
|
Data: json.RawMessage(`{}`),
|
|
}
|
|
insertTask(t, task)
|
|
|
|
const goroutines = 5
|
|
wins := make([]bool, goroutines)
|
|
var wg sync.WaitGroup
|
|
wg.Add(goroutines)
|
|
|
|
for i := 0; i < goroutines; i++ {
|
|
go func(idx int) {
|
|
defer wg.Done()
|
|
t := &Task{}
|
|
*t = Task{
|
|
ID: task.ID,
|
|
TaskID: task.TaskID,
|
|
Status: TaskStatusSuccess,
|
|
Progress: "100%",
|
|
Quota: task.Quota,
|
|
Data: json.RawMessage(`{}`),
|
|
}
|
|
t.CreatedAt = task.CreatedAt
|
|
t.UpdatedAt = time.Now().Unix()
|
|
won, err := t.UpdateWithStatus(TaskStatusInProgress)
|
|
if err == nil {
|
|
wins[idx] = won
|
|
}
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
|
|
winCount := 0
|
|
for _, w := range wins {
|
|
if w {
|
|
winCount++
|
|
}
|
|
}
|
|
assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS")
|
|
}
|