Mirror of https://github.com/QuantumNous/new-api.git, synced 2026-03-30 03:43:39 +00:00.
refactor(task): enhance UpdateWithStatus for CAS updates and add integration tests
- Updated UpdateWithStatus method to use Model().Select("*").Updates() for conditional updates, preventing GORM's INSERT fallback.
- Introduced comprehensive integration tests for UpdateWithStatus, covering scenarios for winning and losing CAS updates, as well as concurrent updates.
- Added task_cas_test.go to validate the new behavior and ensure data integrity during concurrent state transitions.
This commit is contained in:
@@ -160,8 +160,10 @@ func (midjourney *Midjourney) Update() error {
|
||||
// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
|
||||
// Returns (true, nil) if this caller won the update, (false, nil) if
|
||||
// another process already moved the task out of fromStatus.
|
||||
// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
|
||||
// Uses Model().Select("*").Updates() to avoid GORM Save()'s INSERT fallback.
|
||||
func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) {
|
||||
result := DB.Where("status = ?", fromStatus).Save(midjourney)
|
||||
result := DB.Model(midjourney).Where("status = ?", fromStatus).Select("*").Updates(midjourney)
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
|
||||
@@ -388,8 +388,12 @@ func (Task *Task) Update() error {
|
||||
// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS).
|
||||
// Returns (true, nil) if this caller won the update, (false, nil) if
|
||||
// another process already moved the task out of fromStatus.
|
||||
//
|
||||
// Uses Model().Select("*").Updates() instead of Save() because GORM's Save
|
||||
// falls back to INSERT ON CONFLICT when the WHERE-guarded UPDATE matches
|
||||
// zero rows, which silently bypasses the CAS guard.
|
||||
func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) {
|
||||
result := DB.Where("status = ?", fromStatus).Save(t)
|
||||
result := DB.Model(t).Where("status = ?", fromStatus).Select("*").Updates(t)
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
|
||||
217
model/task_cas_test.go
Normal file
217
model/task_cas_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TestMain wires the model package's global DB handles to a fresh in-memory
// SQLite database, pins common feature flags to deterministic values, and
// migrates the schema before running the package's tests.
func TestMain(m *testing.M) {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		panic("failed to open test db: " + err.Error())
	}
	// Both the main and log databases point at the same in-memory instance.
	DB = db
	LOG_DB = db

	common.UsingSQLite = true
	common.RedisEnabled = false
	common.BatchUpdateEnabled = false
	common.LogConsumeEnabled = true

	sqlDB, err := db.DB()
	if err != nil {
		panic("failed to get sql.DB: " + err.Error())
	}
	// A single connection keeps the :memory: database shared across goroutines
	// and serializes concurrent access (SQLite has no real row-level locking).
	sqlDB.SetMaxOpenConns(1)

	if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil {
		panic("failed to migrate: " + err.Error())
	}

	os.Exit(m.Run())
}
|
||||
|
||||
// truncateTables registers a cleanup that empties every table used by these
// tests, so each test starts (and leaves) the shared in-memory DB clean.
func truncateTables(t *testing.T) {
	t.Helper()
	t.Cleanup(func() {
		DB.Exec("DELETE FROM tasks")
		DB.Exec("DELETE FROM users")
		DB.Exec("DELETE FROM tokens")
		DB.Exec("DELETE FROM logs")
		DB.Exec("DELETE FROM channels")
	})
}

// insertTask stamps created/updated times on task and persists it, failing the
// test immediately on any insert error.
func insertTask(t *testing.T, task *Task) {
	t.Helper()
	task.CreatedAt = time.Now().Unix()
	task.UpdatedAt = time.Now().Unix()
	require.NoError(t, DB.Create(task).Error)
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Snapshot / Equal — pure logic tests (no DB)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestSnapshotEqual_Same verifies a snapshot compares equal to itself.
func TestSnapshotEqual_Same(t *testing.T) {
	s := taskSnapshot{
		Status:     TaskStatusInProgress,
		Progress:   "50%",
		StartTime:  1000,
		FinishTime: 0,
		FailReason: "",
		ResultURL:  "",
		Data:       json.RawMessage(`{"key":"value"}`),
	}
	assert.True(t, s.Equal(s))
}

// TestSnapshotEqual_DifferentStatus verifies a status change breaks equality.
func TestSnapshotEqual_DifferentStatus(t *testing.T) {
	a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)}
	b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)}
	assert.False(t, a.Equal(b))
}

// TestSnapshotEqual_DifferentProgress verifies a progress change breaks equality.
func TestSnapshotEqual_DifferentProgress(t *testing.T) {
	a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)}
	b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)}
	assert.False(t, a.Equal(b))
}

// TestSnapshotEqual_DifferentData verifies a payload change breaks equality.
func TestSnapshotEqual_DifferentData(t *testing.T) {
	a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)}
	b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)}
	assert.False(t, a.Equal(b))
}

// TestSnapshotEqual_NilVsEmpty documents that a nil Data slice and an empty
// one are treated as equal by Equal.
func TestSnapshotEqual_NilVsEmpty(t *testing.T) {
	a := taskSnapshot{Status: TaskStatusInProgress, Data: nil}
	b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}}
	// bytes.Equal(nil, []byte{}) == true
	assert.True(t, a.Equal(b))
}
|
||||
|
||||
// TestSnapshot_Roundtrip verifies Task.Snapshot() copies every tracked field
// (including the private-data ResultURL) into the snapshot unchanged.
func TestSnapshot_Roundtrip(t *testing.T) {
	task := &Task{
		Status:     TaskStatusInProgress,
		Progress:   "42%",
		StartTime:  1234,
		FinishTime: 5678,
		FailReason: "timeout",
		PrivateData: TaskPrivateData{
			ResultURL: "https://example.com/result.mp4",
		},
		Data: json.RawMessage(`{"model":"test-model"}`),
	}
	snap := task.Snapshot()
	assert.Equal(t, task.Status, snap.Status)
	assert.Equal(t, task.Progress, snap.Progress)
	assert.Equal(t, task.StartTime, snap.StartTime)
	assert.Equal(t, task.FinishTime, snap.FinishTime)
	assert.Equal(t, task.FailReason, snap.FailReason)
	assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL)
	assert.JSONEq(t, string(task.Data), string(snap.Data))
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// UpdateWithStatus CAS — DB integration tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestUpdateWithStatus_Win checks the happy CAS path: when the DB row is still
// in fromStatus, the update applies and the caller is told it won.
func TestUpdateWithStatus_Win(t *testing.T) {
	truncateTables(t)

	task := &Task{
		TaskID:   "task_cas_win",
		Status:   TaskStatusInProgress,
		Progress: "50%",
		Data:     json.RawMessage(`{}`),
	}
	insertTask(t, task)

	task.Status = TaskStatusSuccess
	task.Progress = "100%"
	won, err := task.UpdateWithStatus(TaskStatusInProgress)
	require.NoError(t, err)
	assert.True(t, won)

	// Reload from the DB to prove the update was actually persisted.
	var reloaded Task
	require.NoError(t, DB.First(&reloaded, task.ID).Error)
	assert.EqualValues(t, TaskStatusSuccess, reloaded.Status)
	assert.Equal(t, "100%", reloaded.Progress)
}

// TestUpdateWithStatus_Lose checks the CAS guard: when the DB row is NOT in
// fromStatus, the update must be a no-op and report a loss — not an error.
func TestUpdateWithStatus_Lose(t *testing.T) {
	truncateTables(t)

	task := &Task{
		TaskID: "task_cas_lose",
		Status: TaskStatusFailure,
		Data:   json.RawMessage(`{}`),
	}
	insertTask(t, task)

	task.Status = TaskStatusSuccess
	won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus
	require.NoError(t, err)
	assert.False(t, won)

	var reloaded Task
	require.NoError(t, DB.First(&reloaded, task.ID).Error)
	assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged
}
|
||||
|
||||
func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) {
|
||||
truncateTables(t)
|
||||
|
||||
task := &Task{
|
||||
TaskID: "task_cas_race",
|
||||
Status: TaskStatusInProgress,
|
||||
Quota: 1000,
|
||||
Data: json.RawMessage(`{}`),
|
||||
}
|
||||
insertTask(t, task)
|
||||
|
||||
const goroutines = 5
|
||||
wins := make([]bool, goroutines)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
t := &Task{}
|
||||
*t = Task{
|
||||
ID: task.ID,
|
||||
TaskID: task.TaskID,
|
||||
Status: TaskStatusSuccess,
|
||||
Progress: "100%",
|
||||
Quota: task.Quota,
|
||||
Data: json.RawMessage(`{}`),
|
||||
}
|
||||
t.CreatedAt = task.CreatedAt
|
||||
t.UpdatedAt = time.Now().Unix()
|
||||
won, err := t.UpdateWithStatus(TaskStatusInProgress)
|
||||
if err == nil {
|
||||
wins[idx] = won
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
winCount := 0
|
||||
for _, w := range wins {
|
||||
if w {
|
||||
winCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS")
|
||||
}
|
||||
606
service/task_billing_test.go
Normal file
606
service/task_billing_test.go
Normal file
@@ -0,0 +1,606 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TestMain wires the model package's global DB handles to a fresh in-memory
// SQLite database, pins common feature flags, and migrates every table the
// billing tests touch before running the suite.
func TestMain(m *testing.M) {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		panic("failed to open test db: " + err.Error())
	}
	sqlDB, err := db.DB()
	if err != nil {
		panic("failed to get sql.DB: " + err.Error())
	}
	// One connection keeps the :memory: DB shared and serializes access.
	sqlDB.SetMaxOpenConns(1)

	model.DB = db
	model.LOG_DB = db

	common.UsingSQLite = true
	common.RedisEnabled = false
	common.BatchUpdateEnabled = false
	common.LogConsumeEnabled = true

	if err := db.AutoMigrate(
		&model.Task{},
		&model.User{},
		&model.Token{},
		&model.Log{},
		&model.Channel{},
		&model.UserSubscription{},
	); err != nil {
		panic("failed to migrate: " + err.Error())
	}

	os.Exit(m.Run())
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Seed helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// truncate registers a cleanup that empties every table these billing tests
// touch, so each test starts from (and leaves behind) an empty database.
func truncate(t *testing.T) {
	t.Helper()
	t.Cleanup(func() {
		model.DB.Exec("DELETE FROM tasks")
		model.DB.Exec("DELETE FROM users")
		model.DB.Exec("DELETE FROM tokens")
		model.DB.Exec("DELETE FROM logs")
		model.DB.Exec("DELETE FROM channels")
		model.DB.Exec("DELETE FROM user_subscriptions")
	})
}
|
||||
|
||||
// seedUser inserts an enabled user with the given id and wallet quota.
func seedUser(t *testing.T, id int, quota int) {
	t.Helper()
	user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled}
	require.NoError(t, model.DB.Create(user).Error)
}

// seedToken inserts an enabled API token for userId with the given remaining
// quota and zero used quota.
func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) {
	t.Helper()
	token := &model.Token{
		Id:          id,
		UserId:      userId,
		Key:         key,
		Name:        "test_token",
		Status:      common.TokenStatusEnabled,
		RemainQuota: remainQuota,
		UsedQuota:   0,
	}
	require.NoError(t, model.DB.Create(token).Error)
}

// seedSubscription inserts an active 30-day subscription for userId with the
// given total and already-used amounts.
func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) {
	t.Helper()
	sub := &model.UserSubscription{
		Id:          id,
		UserId:      userId,
		AmountTotal: amountTotal,
		AmountUsed:  amountUsed,
		Status:      "active",
		StartTime:   time.Now().Unix(),
		EndTime:     time.Now().Add(30 * 24 * time.Hour).Unix(),
	}
	require.NoError(t, model.DB.Create(sub).Error)
}

// seedChannel inserts an enabled upstream channel with the given id.
func seedChannel(t *testing.T, id int) {
	t.Helper()
	ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled}
	require.NoError(t, model.DB.Create(ch).Error)
}
|
||||
|
||||
// makeTask builds an IN_PROGRESS task fixture pre-charged with quota and
// carrying the billing metadata (source, subscription, token, price context)
// that RefundTaskQuota / RecalculateTaskQuota read.
//
// NOTE(review): TaskID is derived from the wall clock with millisecond
// precision; two tasks created within the same millisecond would collide —
// acceptable here because each test truncates its tables.
func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task {
	return &model.Task{
		TaskID:    "task_" + time.Now().Format("150405.000"),
		UserId:    userId,
		ChannelId: channelId,
		Quota:     quota,
		Status:    model.TaskStatus(model.TaskStatusInProgress),
		Group:     "default",
		Data:      json.RawMessage(`{}`),
		CreatedAt: time.Now().Unix(),
		UpdatedAt: time.Now().Unix(),
		Properties: model.Properties{
			OriginModelName: "test-model",
		},
		PrivateData: model.TaskPrivateData{
			BillingSource:  billingSource,
			SubscriptionId: subscriptionId,
			TokenId:        tokenId,
			BillingContext: &model.TaskBillingContext{
				ModelPrice: 0.02,
				GroupRatio: 1.0,
				ModelName:  "test-model",
			},
		},
	}
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Read-back helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// getUserQuota reads the user's current wallet quota straight from the DB.
func getUserQuota(t *testing.T, id int) int {
	t.Helper()
	var user model.User
	require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error)
	return user.Quota
}

// getTokenRemainQuota reads the token's remaining quota from the DB.
func getTokenRemainQuota(t *testing.T, id int) int {
	t.Helper()
	var token model.Token
	require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error)
	return token.RemainQuota
}

// getTokenUsedQuota reads the token's used quota from the DB.
func getTokenUsedQuota(t *testing.T, id int) int {
	t.Helper()
	var token model.Token
	require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error)
	return token.UsedQuota
}

// getSubscriptionUsed reads the subscription's consumed amount from the DB.
func getSubscriptionUsed(t *testing.T, id int) int64 {
	t.Helper()
	var sub model.UserSubscription
	require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error)
	return sub.AmountUsed
}

// getLastLog returns the most recent billing log, or nil when none exists
// (lookup errors are deliberately folded into the nil result so callers can
// assert "no log was written" with a simple nil check).
func getLastLog(t *testing.T) *model.Log {
	t.Helper()
	var log model.Log
	err := model.LOG_DB.Order("id desc").First(&log).Error
	if err != nil {
		return nil
	}
	return &log
}

// countLogs returns the total number of billing log rows.
func countLogs(t *testing.T) int64 {
	t.Helper()
	var count int64
	model.LOG_DB.Model(&model.Log{}).Count(&count)
	return count
}
|
||||
|
||||
// ===========================================================================
|
||||
// RefundTaskQuota tests
|
||||
// ===========================================================================
|
||||
|
||||
// TestRefundTaskQuota_Wallet verifies a wallet-billed refund returns the
// pre-consumed quota to both the user and the token, and writes a refund log.
func TestRefundTaskQuota_Wallet(t *testing.T) {
	truncate(t)
	ctx := context.Background()

	const userID, tokenID, channelID = 1, 1, 1
	const initQuota, preConsumed = 10000, 3000
	const tokenRemain = 5000

	seedUser(t, userID, initQuota)
	seedToken(t, tokenID, userID, "sk-test-key", tokenRemain)
	seedChannel(t, channelID)

	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)

	RefundTaskQuota(ctx, task, "task failed: upstream error")

	// User quota should increase by preConsumed
	assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))

	// Token remain_quota should increase, used_quota should decrease.
	// used_quota starts at 0 in the fixture, so the refund drives it negative.
	assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))
	assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID))

	// A refund log should be created
	log := getLastLog(t)
	require.NotNil(t, log)
	assert.Equal(t, model.LogTypeRefund, log.Type)
	assert.Equal(t, preConsumed, log.Quota)
	assert.Equal(t, "test-model", log.ModelName)
}

// TestRefundTaskQuota_Subscription verifies a subscription-billed refund
// decrements the subscription's used amount and refunds the token.
func TestRefundTaskQuota_Subscription(t *testing.T) {
	truncate(t)
	ctx := context.Background()

	const userID, tokenID, channelID, subID = 2, 2, 2, 1
	const preConsumed = 2000
	const subTotal, subUsed int64 = 100000, 50000
	const tokenRemain = 8000

	seedUser(t, userID, 0)
	seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain)
	seedChannel(t, channelID)
	seedSubscription(t, subID, userID, subTotal, subUsed)

	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)

	RefundTaskQuota(ctx, task, "subscription task failed")

	// Subscription used should decrease by preConsumed
	assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID))

	// Token should also be refunded
	assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))

	log := getLastLog(t)
	require.NotNil(t, log)
	assert.Equal(t, model.LogTypeRefund, log.Type)
}

// TestRefundTaskQuota_ZeroQuota verifies a task with zero pre-consumed quota
// is a billing no-op: no quota movement and no log row.
func TestRefundTaskQuota_ZeroQuota(t *testing.T) {
	truncate(t)
	ctx := context.Background()

	const userID = 3
	seedUser(t, userID, 5000)

	task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0)

	RefundTaskQuota(ctx, task, "zero quota task")

	// No change to user quota
	assert.Equal(t, 5000, getUserQuota(t, userID))

	// No log created
	assert.Equal(t, int64(0), countLogs(t))
}

// TestRefundTaskQuota_NoToken verifies the user is still refunded (and a log
// written) when the task carries no token id.
func TestRefundTaskQuota_NoToken(t *testing.T) {
	truncate(t)
	ctx := context.Background()

	const userID, channelID = 4, 4
	const initQuota, preConsumed = 10000, 1500

	seedUser(t, userID, initQuota)
	seedChannel(t, channelID)

	task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0

	RefundTaskQuota(ctx, task, "no token task failed")

	// User quota refunded
	assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))

	// Log created
	log := getLastLog(t)
	require.NotNil(t, log)
	assert.Equal(t, model.LogTypeRefund, log.Type)
}
|
||||
|
||||
// ===========================================================================
|
||||
// RecalculateTaskQuota tests
|
||||
// ===========================================================================
|
||||
|
||||
// TestRecalculate_PositiveDelta verifies an under-charged task is charged the
// extra delta from user and token, and a consume log is written for the delta.
func TestRecalculate_PositiveDelta(t *testing.T) {
	truncate(t)
	ctx := context.Background()

	const userID, tokenID, channelID = 10, 10, 10
	const initQuota, preConsumed = 10000, 2000
	const actualQuota = 3000 // under-charged by 1000
	const tokenRemain = 5000

	seedUser(t, userID, initQuota)
	seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain)
	seedChannel(t, channelID)

	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)

	RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")

	// User quota should decrease by the delta (1000 additional charge)
	assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID))

	// Token should also be charged the delta
	assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID))

	// task.Quota should be updated to actualQuota
	assert.Equal(t, actualQuota, task.Quota)

	// Log type should be Consume (additional charge)
	log := getLastLog(t)
	require.NotNil(t, log)
	assert.Equal(t, model.LogTypeConsume, log.Type)
	assert.Equal(t, actualQuota-preConsumed, log.Quota)
}

// TestRecalculate_NegativeDelta verifies an over-charged task refunds the
// overpayment to user and token, and a refund log is written for the delta.
func TestRecalculate_NegativeDelta(t *testing.T) {
	truncate(t)
	ctx := context.Background()

	const userID, tokenID, channelID = 11, 11, 11
	const initQuota, preConsumed = 10000, 5000
	const actualQuota = 3000 // over-charged by 2000
	const tokenRemain = 5000

	seedUser(t, userID, initQuota)
	seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain)
	seedChannel(t, channelID)

	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)

	RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment")

	// User quota should increase by abs(delta) = 2000 (refund overpayment)
	assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))

	// Token should be refunded the difference
	assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))

	// task.Quota updated
	assert.Equal(t, actualQuota, task.Quota)

	// Log type should be Refund
	log := getLastLog(t)
	require.NotNil(t, log)
	assert.Equal(t, model.LogTypeRefund, log.Type)
	assert.Equal(t, preConsumed-actualQuota, log.Quota)
}

// TestRecalculate_ZeroDelta verifies an exact pre-charge is a no-op: no quota
// movement and no log row.
func TestRecalculate_ZeroDelta(t *testing.T) {
	truncate(t)
	ctx := context.Background()

	const userID = 12
	const initQuota, preConsumed = 10000, 3000

	seedUser(t, userID, initQuota)

	task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0)

	RecalculateTaskQuota(ctx, task, preConsumed, "exact match")

	// No change to user quota
	assert.Equal(t, initQuota, getUserQuota(t, userID))

	// No log created (delta is zero)
	assert.Equal(t, int64(0), countLogs(t))
}

// TestRecalculate_ActualQuotaZero verifies actualQuota == 0 triggers the early
// return and leaves billing untouched.
func TestRecalculate_ActualQuotaZero(t *testing.T) {
	truncate(t)
	ctx := context.Background()

	const userID = 13
	const initQuota = 10000

	seedUser(t, userID, initQuota)

	task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0)

	RecalculateTaskQuota(ctx, task, 0, "zero actual")

	// No change (early return)
	assert.Equal(t, initQuota, getUserQuota(t, userID))
	assert.Equal(t, int64(0), countLogs(t))
}

// TestRecalculate_Subscription_NegativeDelta verifies an over-charge against a
// subscription is refunded to the subscription's used amount and the token.
func TestRecalculate_Subscription_NegativeDelta(t *testing.T) {
	truncate(t)
	ctx := context.Background()

	const userID, tokenID, channelID, subID = 14, 14, 14, 2
	const preConsumed = 5000
	const actualQuota = 2000 // over-charged by 3000
	const subTotal, subUsed int64 = 100000, 50000
	const tokenRemain = 8000

	seedUser(t, userID, 0)
	seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain)
	seedChannel(t, channelID)
	seedSubscription(t, subID, userID, subTotal, subUsed)

	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID)

	RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge")

	// Subscription used should decrease by delta (refund 3000)
	assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID))

	// Token refunded
	assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))

	assert.Equal(t, actualQuota, task.Quota)

	log := getLastLog(t)
	require.NotNil(t, log)
	assert.Equal(t, model.LogTypeRefund, log.Type)
}
|
||||
|
||||
// ===========================================================================
|
||||
// CAS + Billing integration tests
|
||||
// Simulates the flow in updateVideoSingleTask (service/task_polling.go)
|
||||
// ===========================================================================
|
||||
|
||||
// simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask.
// It takes a persisted task (already in DB), applies the new status, and performs
// the conditional update + billing exactly as the polling loop does.
//
// Billing rules mirrored here:
//   - SUCCESS settles (recalculates) the quota; FAILURE with a non-zero
//     pre-charge refunds it.
//   - Billing only fires when this caller WINS the CAS transition; losing the
//     CAS (or hitting a DB error) suppresses both settle and refund so another
//     process cannot be double-billed.
func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) {
	// Snapshot the pre-update state; its Status is the CAS guard value.
	snap := task.Snapshot()

	shouldRefund := false
	shouldSettle := false
	quota := task.Quota

	task.Status = newStatus
	switch string(newStatus) {
	case model.TaskStatusSuccess:
		task.Progress = "100%"
		task.FinishTime = 9999
		shouldSettle = true
	case model.TaskStatusFailure:
		task.Progress = "100%"
		task.FinishTime = 9999
		task.FailReason = "upstream error"
		if quota != 0 {
			shouldRefund = true
		}
	default:
		// Non-terminal poll: just bump progress.
		task.Progress = "50%"
	}

	isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure)
	if isDone && snap.Status != task.Status {
		// Terminal transition: only the CAS winner may bill.
		won, err := task.UpdateWithStatus(snap.Status)
		if err != nil {
			shouldRefund = false
			shouldSettle = false
		} else if !won {
			shouldRefund = false
			shouldSettle = false
		}
	} else if !snap.Equal(task.Snapshot()) {
		// Non-terminal change: persist best-effort; the outcome is irrelevant
		// to billing, so the result is deliberately ignored.
		_, _ = task.UpdateWithStatus(snap.Status)
	}

	if shouldSettle && actualQuota > 0 {
		RecalculateTaskQuota(ctx, task, actualQuota, "test settle")
	}
	if shouldRefund {
		RefundTaskQuota(ctx, task, task.FailReason)
	}
}
|
||||
|
||||
// TestCASGuardedRefund_Win: this process wins the IN_PROGRESS -> FAILURE CAS,
// so the refund fires and a refund log is written.
func TestCASGuardedRefund_Win(t *testing.T) {
	truncate(t)
	ctx := context.Background()

	const userID, tokenID, channelID = 20, 20, 20
	const initQuota, preConsumed = 10000, 4000
	const tokenRemain = 6000

	seedUser(t, userID, initQuota)
	seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain)
	seedChannel(t, channelID)

	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
	task.Status = model.TaskStatus(model.TaskStatusInProgress)
	require.NoError(t, model.DB.Create(task).Error)

	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)

	// CAS wins: task in DB should now be FAILURE
	var reloaded model.Task
	require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
	assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status)

	// Refund should have happened
	assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID))
	assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID))

	log := getLastLog(t)
	require.NotNil(t, log)
	assert.Equal(t, model.LogTypeRefund, log.Type)
}

// TestCASGuardedRefund_Lose: another process already moved the row to FAILURE,
// so this process loses the CAS and must NOT refund (no double refund).
func TestCASGuardedRefund_Lose(t *testing.T) {
	truncate(t)
	ctx := context.Background()

	const userID, tokenID, channelID = 21, 21, 21
	const initQuota, preConsumed = 10000, 4000
	const tokenRemain = 6000

	seedUser(t, userID, initQuota)
	seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain)
	seedChannel(t, channelID)

	// Create task with IN_PROGRESS in DB
	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
	task.Status = model.TaskStatus(model.TaskStatusInProgress)
	require.NoError(t, model.DB.Create(task).Error)

	// Simulate another process already transitioning to FAILURE
	model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure)

	// Our process still has the old in-memory state (IN_PROGRESS) and tries to transition
	// task.Status is still IN_PROGRESS in the snapshot
	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0)

	// CAS lost: user quota should NOT change (no double refund)
	assert.Equal(t, initQuota, getUserQuota(t, userID))
	assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID))

	// No billing log should be created
	assert.Equal(t, int64(0), countLogs(t))
}

// TestCASGuardedSettle_Win: winning the IN_PROGRESS -> SUCCESS CAS settles the
// quota, refunding the over-charged portion of the pre-charge.
func TestCASGuardedSettle_Win(t *testing.T) {
	truncate(t)
	ctx := context.Background()

	const userID, tokenID, channelID = 22, 22, 22
	const initQuota, preConsumed = 10000, 5000
	const actualQuota = 3000 // over-charged, should get partial refund
	const tokenRemain = 8000

	seedUser(t, userID, initQuota)
	seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain)
	seedChannel(t, channelID)

	task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0)
	task.Status = model.TaskStatus(model.TaskStatusInProgress)
	require.NoError(t, model.DB.Create(task).Error)

	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota)

	// CAS wins: task should be SUCCESS
	var reloaded model.Task
	require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
	assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status)

	// Settlement should refund the over-charge (5000 - 3000 = 2000 back to user)
	assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID))
	assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID))

	// task.Quota should be updated to actualQuota
	assert.Equal(t, actualQuota, task.Quota)
}

// TestNonTerminalUpdate_NoBilling: a still-IN_PROGRESS poll persists progress
// but must not touch quota or write billing logs.
func TestNonTerminalUpdate_NoBilling(t *testing.T) {
	truncate(t)
	ctx := context.Background()

	const userID, channelID = 23, 23
	const initQuota, preConsumed = 10000, 3000

	seedUser(t, userID, initQuota)
	seedChannel(t, channelID)

	task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0)
	task.Status = model.TaskStatus(model.TaskStatusInProgress)
	task.Progress = "20%"
	require.NoError(t, model.DB.Create(task).Error)

	// Simulate a non-terminal poll update (still IN_PROGRESS, progress changed)
	simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0)

	// User quota should NOT change
	assert.Equal(t, initQuota, getUserQuota(t, userID))

	// No billing log
	assert.Equal(t, int64(0), countLogs(t))

	// Task progress should be updated in DB
	var reloaded model.Task
	require.NoError(t, model.DB.First(&reloaded, task.ID).Error)
	assert.Equal(t, "50%", reloaded.Progress)
}
|
||||
Reference in New Issue
Block a user