diff --git a/model/midjourney.go b/model/midjourney.go index 9867e8a96..e1a8d772b 100644 --- a/model/midjourney.go +++ b/model/midjourney.go @@ -160,8 +160,10 @@ func (midjourney *Midjourney) Update() error { // UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). // Returns (true, nil) if this caller won the update, (false, nil) if // another process already moved the task out of fromStatus. +// UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). +// Uses Model().Select("*").Updates() to avoid GORM Save()'s INSERT fallback. func (midjourney *Midjourney) UpdateWithStatus(fromStatus string) (bool, error) { - result := DB.Where("status = ?", fromStatus).Save(midjourney) + result := DB.Model(midjourney).Where("status = ?", fromStatus).Select("*").Updates(midjourney) if result.Error != nil { return false, result.Error } diff --git a/model/task.go b/model/task.go index 4d1482f8b..0cf6bd47e 100644 --- a/model/task.go +++ b/model/task.go @@ -388,8 +388,12 @@ func (Task *Task) Update() error { // UpdateWithStatus performs a conditional UPDATE guarded by fromStatus (CAS). // Returns (true, nil) if this caller won the update, (false, nil) if // another process already moved the task out of fromStatus. +// +// Uses Model().Select("*").Updates() instead of Save() because GORM's Save +// falls back to INSERT ON CONFLICT when the WHERE-guarded UPDATE matches +// zero rows, which silently bypasses the CAS guard. func (t *Task) UpdateWithStatus(fromStatus TaskStatus) (bool, error) { - result := DB.Where("status = ?", fromStatus).Save(t) + result := DB.Model(t).Where("status = ?", fromStatus).Select("*").Updates(t) if result.Error != nil { return false, result.Error } diff --git a/model/task_cas_test.go b/model/task_cas_test.go new file mode 100644 index 000000000..3449c6d26 --- /dev/null +++ b/model/task_cas_test.go @@ -0,0 +1,217 @@ +package model + +import ( + "encoding/json" + "os" + "sync" + "testing" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func TestMain(m *testing.M) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + panic("failed to open test db: " + err.Error()) + } + DB = db + LOG_DB = db + + common.UsingSQLite = true + common.RedisEnabled = false + common.BatchUpdateEnabled = false + common.LogConsumeEnabled = true + + sqlDB, err := db.DB() + if err != nil { + panic("failed to get sql.DB: " + err.Error()) + } + sqlDB.SetMaxOpenConns(1) + + if err := db.AutoMigrate(&Task{}, &User{}, &Token{}, &Log{}, &Channel{}); err != nil { + panic("failed to migrate: " + err.Error()) + } + + os.Exit(m.Run()) +} + +func truncateTables(t *testing.T) { + t.Helper() + t.Cleanup(func() { + DB.Exec("DELETE FROM tasks") + DB.Exec("DELETE FROM users") + DB.Exec("DELETE FROM tokens") + DB.Exec("DELETE FROM logs") + DB.Exec("DELETE FROM channels") + }) +} + +func insertTask(t *testing.T, task *Task) { + t.Helper() + task.CreatedAt = time.Now().Unix() + task.UpdatedAt = time.Now().Unix() + require.NoError(t, DB.Create(task).Error) +} + +// --------------------------------------------------------------------------- +// Snapshot / Equal — pure logic tests (no DB) +// --------------------------------------------------------------------------- + +func TestSnapshotEqual_Same(t *testing.T) { + s := taskSnapshot{ + Status: TaskStatusInProgress, + Progress: "50%", + StartTime: 1000, + FinishTime: 0, + FailReason: "", + ResultURL: "", + Data: json.RawMessage(`{"key":"value"}`), + } + assert.True(t, s.Equal(s)) +} + +func TestSnapshotEqual_DifferentStatus(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{}`)} + b := taskSnapshot{Status: TaskStatusSuccess, Data: json.RawMessage(`{}`)} + assert.False(t, a.Equal(b)) +} + +func TestSnapshotEqual_DifferentProgress(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Progress: "30%", Data: json.RawMessage(`{}`)} + b := taskSnapshot{Status: TaskStatusInProgress, Progress: "60%", Data: json.RawMessage(`{}`)} + assert.False(t, a.Equal(b)) +} + +func TestSnapshotEqual_DifferentData(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":1}`)} + b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage(`{"a":2}`)} + assert.False(t, a.Equal(b)) +} + +func TestSnapshotEqual_NilVsEmpty(t *testing.T) { + a := taskSnapshot{Status: TaskStatusInProgress, Data: nil} + b := taskSnapshot{Status: TaskStatusInProgress, Data: json.RawMessage{}} + // bytes.Equal(nil, []byte{}) == true + assert.True(t, a.Equal(b)) +} + +func TestSnapshot_Roundtrip(t *testing.T) { + task := &Task{ + Status: TaskStatusInProgress, + Progress: "42%", + StartTime: 1234, + FinishTime: 5678, + FailReason: "timeout", + PrivateData: TaskPrivateData{ + ResultURL: "https://example.com/result.mp4", + }, + Data: json.RawMessage(`{"model":"test-model"}`), + } + snap := task.Snapshot() + assert.Equal(t, task.Status, snap.Status) + assert.Equal(t, task.Progress, snap.Progress) + assert.Equal(t, task.StartTime, snap.StartTime) + assert.Equal(t, task.FinishTime, snap.FinishTime) + assert.Equal(t, task.FailReason, snap.FailReason) + assert.Equal(t, task.PrivateData.ResultURL, snap.ResultURL) + assert.JSONEq(t, string(task.Data), string(snap.Data)) +} + +// --------------------------------------------------------------------------- +// UpdateWithStatus CAS — DB integration tests +// --------------------------------------------------------------------------- + +func TestUpdateWithStatus_Win(t *testing.T) { + truncateTables(t) + + task := &Task{ + TaskID: "task_cas_win", + Status: TaskStatusInProgress, + Progress: "50%", + Data: json.RawMessage(`{}`), + } + insertTask(t, task) + + task.Status = TaskStatusSuccess + task.Progress = "100%" + won, err := task.UpdateWithStatus(TaskStatusInProgress) + require.NoError(t, err) + assert.True(t, won) + + var reloaded Task + require.NoError(t, DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, TaskStatusSuccess, reloaded.Status) + assert.Equal(t, "100%", reloaded.Progress) +} + +func TestUpdateWithStatus_Lose(t *testing.T) { + truncateTables(t) + + task := &Task{ + TaskID: "task_cas_lose", + Status: TaskStatusFailure, + Data: json.RawMessage(`{}`), + } + insertTask(t, task) + + task.Status = TaskStatusSuccess + won, err := task.UpdateWithStatus(TaskStatusInProgress) // wrong fromStatus + require.NoError(t, err) + assert.False(t, won) + + var reloaded Task + require.NoError(t, DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, TaskStatusFailure, reloaded.Status) // unchanged +} + +func TestUpdateWithStatus_ConcurrentWinner(t *testing.T) { + truncateTables(t) + + task := &Task{ + TaskID: "task_cas_race", + Status: TaskStatusInProgress, + Quota: 1000, + Data: json.RawMessage(`{}`), + } + insertTask(t, task) + + const goroutines = 5 + wins := make([]bool, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + t := &Task{} + *t = Task{ + ID: task.ID, + TaskID: task.TaskID, + Status: TaskStatusSuccess, + Progress: "100%", + Quota: task.Quota, + Data: json.RawMessage(`{}`), + } + t.CreatedAt = task.CreatedAt + t.UpdatedAt = time.Now().Unix() + won, err := t.UpdateWithStatus(TaskStatusInProgress) + if err == nil { + wins[idx] = won + } + }(i) + } + wg.Wait() + + winCount := 0 + for _, w := range wins { + if w { + winCount++ + } + } + assert.Equal(t, 1, winCount, "exactly one goroutine should win the CAS") +} diff --git a/service/task_billing_test.go b/service/task_billing_test.go new file mode 100644 index 000000000..6c2d231d5 --- /dev/null +++ b/service/task_billing_test.go @@ -0,0 +1,606 @@ +package service + +import ( + "context" + "encoding/json" + "os" + "testing" + "time" + + "github.com/QuantumNous/new-api/common" + "github.com/QuantumNous/new-api/model" + "github.com/glebarez/sqlite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +func TestMain(m *testing.M) { + db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) + if err != nil { + panic("failed to open test db: " + err.Error()) + } + sqlDB, err := db.DB() + if err != nil { + panic("failed to get sql.DB: " + err.Error()) + } + sqlDB.SetMaxOpenConns(1) + + model.DB = db + model.LOG_DB = db + + common.UsingSQLite = true + common.RedisEnabled = false + common.BatchUpdateEnabled = false + common.LogConsumeEnabled = true + + if err := db.AutoMigrate( + &model.Task{}, + &model.User{}, + &model.Token{}, + &model.Log{}, + &model.Channel{}, + &model.UserSubscription{}, + ); err != nil { + panic("failed to migrate: " + err.Error()) + } + + os.Exit(m.Run()) +} + +// --------------------------------------------------------------------------- +// Seed helpers +// --------------------------------------------------------------------------- + +func truncate(t *testing.T) { + t.Helper() + t.Cleanup(func() { + model.DB.Exec("DELETE FROM tasks") + model.DB.Exec("DELETE FROM users") + model.DB.Exec("DELETE FROM tokens") + model.DB.Exec("DELETE FROM logs") + model.DB.Exec("DELETE FROM channels") + model.DB.Exec("DELETE FROM user_subscriptions") + }) +} + +func seedUser(t *testing.T, id int, quota int) { + t.Helper() + user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled} + require.NoError(t, model.DB.Create(user).Error) +} + +func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) { + t.Helper() + token := &model.Token{ + Id: id, + UserId: userId, + Key: key, + Name: "test_token", + Status: common.TokenStatusEnabled, + RemainQuota: remainQuota, + UsedQuota: 0, + } + require.NoError(t, model.DB.Create(token).Error) +} + +func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) { + t.Helper() + sub := &model.UserSubscription{ + Id: id, + UserId: userId, + AmountTotal: amountTotal, + AmountUsed: amountUsed, + Status: "active", + StartTime: time.Now().Unix(), + EndTime: time.Now().Add(30 * 24 * time.Hour).Unix(), + } + require.NoError(t, model.DB.Create(sub).Error) +} + +func seedChannel(t *testing.T, id int) { + t.Helper() + ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled} + require.NoError(t, model.DB.Create(ch).Error) +} + +func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task { + return &model.Task{ + TaskID: "task_" + time.Now().Format("150405.000"), + UserId: userId, + ChannelId: channelId, + Quota: quota, + Status: model.TaskStatus(model.TaskStatusInProgress), + Group: "default", + Data: json.RawMessage(`{}`), + CreatedAt: time.Now().Unix(), + UpdatedAt: time.Now().Unix(), + Properties: model.Properties{ + OriginModelName: "test-model", + }, + PrivateData: model.TaskPrivateData{ + BillingSource: billingSource, + SubscriptionId: subscriptionId, + TokenId: tokenId, + BillingContext: &model.TaskBillingContext{ + ModelPrice: 0.02, + GroupRatio: 1.0, + ModelName: "test-model", + }, + }, + } +} + +// --------------------------------------------------------------------------- +// Read-back helpers +// --------------------------------------------------------------------------- + +func getUserQuota(t *testing.T, id int) int { + t.Helper() + var user model.User + require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error) + return user.Quota +} + +func getTokenRemainQuota(t *testing.T, id int) int { + t.Helper() + var token model.Token + require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error) + return token.RemainQuota +} + +func getTokenUsedQuota(t *testing.T, id int) int { + t.Helper() + var token model.Token + require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error) + return token.UsedQuota +} + +func getSubscriptionUsed(t *testing.T, id int) int64 { + t.Helper() + var sub model.UserSubscription + require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error) + return sub.AmountUsed +} + +func getLastLog(t *testing.T) *model.Log { + t.Helper() + var log model.Log + err := model.LOG_DB.Order("id desc").First(&log).Error + if err != nil { + return nil + } + return &log +} + +func countLogs(t *testing.T) int64 { + t.Helper() + var count int64 + model.LOG_DB.Model(&model.Log{}).Count(&count) + return count +} + +// =========================================================================== +// RefundTaskQuota tests +// =========================================================================== + +func TestRefundTaskQuota_Wallet(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 1, 1, 1 + const initQuota, preConsumed = 10000, 3000 + const tokenRemain = 5000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-test-key", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + + RefundTaskQuota(ctx, task, "task failed: upstream error") + + // User quota should increase by preConsumed + assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) + + // Token remain_quota should increase, used_quota should decrease + assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) + assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID)) + + // A refund log should be created + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) + assert.Equal(t, preConsumed, log.Quota) + assert.Equal(t, "test-model", log.ModelName) +} + +func TestRefundTaskQuota_Subscription(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID, subID = 2, 2, 2, 1 + const preConsumed = 2000 + const subTotal, subUsed int64 = 100000, 50000 + const tokenRemain = 8000 + + seedUser(t, userID, 0) + seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain) + seedChannel(t, channelID) + seedSubscription(t, subID, userID, subTotal, subUsed) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID) + + RefundTaskQuota(ctx, task, "subscription task failed") + + // Subscription used should decrease by preConsumed + assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID)) + + // Token should also be refunded + assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +func TestRefundTaskQuota_ZeroQuota(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID = 3 + seedUser(t, userID, 5000) + + task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0) + + RefundTaskQuota(ctx, task, "zero quota task") + + // No change to user quota + assert.Equal(t, 5000, getUserQuota(t, userID)) + + // No log created + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestRefundTaskQuota_NoToken(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, channelID = 4, 4 + const initQuota, preConsumed = 10000, 1500 + + seedUser(t, userID, initQuota) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) // TokenId=0 + + RefundTaskQuota(ctx, task, "no token task failed") + + // User quota refunded + assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) + + // Log created + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +// =========================================================================== +// RecalculateTaskQuota tests +// =========================================================================== + +func TestRecalculate_PositiveDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 10, 10, 10 + const initQuota, preConsumed = 10000, 2000 + const actualQuota = 3000 // under-charged by 1000 + const tokenRemain = 5000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment") + + // User quota should decrease by the delta (1000 additional charge) + assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID)) + + // Token should also be charged the delta + assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID)) + + // task.Quota should be updated to actualQuota + assert.Equal(t, actualQuota, task.Quota) + + // Log type should be Consume (additional charge) + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeConsume, log.Type) + assert.Equal(t, actualQuota-preConsumed, log.Quota) +} + +func TestRecalculate_NegativeDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 11, 11, 11 + const initQuota, preConsumed = 10000, 5000 + const actualQuota = 3000 // over-charged by 2000 + const tokenRemain = 5000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment") + + // User quota should increase by abs(delta) = 2000 (refund overpayment) + assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID)) + + // Token should be refunded the difference + assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) + + // task.Quota updated + assert.Equal(t, actualQuota, task.Quota) + + // Log type should be Refund + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) + assert.Equal(t, preConsumed-actualQuota, log.Quota) +} + +func TestRecalculate_ZeroDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID = 12 + const initQuota, preConsumed = 10000, 3000 + + seedUser(t, userID, initQuota) + + task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, preConsumed, "exact match") + + // No change to user quota + assert.Equal(t, initQuota, getUserQuota(t, userID)) + + // No log created (delta is zero) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestRecalculate_ActualQuotaZero(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID = 13 + const initQuota = 10000 + + seedUser(t, userID, initQuota) + + task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0) + + RecalculateTaskQuota(ctx, task, 0, "zero actual") + + // No change (early return) + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestRecalculate_Subscription_NegativeDelta(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID, subID = 14, 14, 14, 2 + const preConsumed = 5000 + const actualQuota = 2000 // over-charged by 3000 + const subTotal, subUsed int64 = 100000, 50000 + const tokenRemain = 8000 + + seedUser(t, userID, 0) + seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain) + seedChannel(t, channelID) + seedSubscription(t, subID, userID, subTotal, subUsed) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID) + + RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge") + + // Subscription used should decrease by delta (refund 3000) + assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID)) + + // Token refunded + assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) + + assert.Equal(t, actualQuota, task.Quota) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +// =========================================================================== +// CAS + Billing integration tests +// Simulates the flow in updateVideoSingleTask (service/task_polling.go) +// =========================================================================== + +// simulatePollBilling reproduces the CAS + billing logic from updateVideoSingleTask. +// It takes a persisted task (already in DB), applies the new status, and performs +// the conditional update + billing exactly as the polling loop does. +func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) { + snap := task.Snapshot() + + shouldRefund := false + shouldSettle := false + quota := task.Quota + + task.Status = newStatus + switch string(newStatus) { + case model.TaskStatusSuccess: + task.Progress = "100%" + task.FinishTime = 9999 + shouldSettle = true + case model.TaskStatusFailure: + task.Progress = "100%" + task.FinishTime = 9999 + task.FailReason = "upstream error" + if quota != 0 { + shouldRefund = true + } + default: + task.Progress = "50%" + } + + isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure) + if isDone && snap.Status != task.Status { + won, err := task.UpdateWithStatus(snap.Status) + if err != nil { + shouldRefund = false + shouldSettle = false + } else if !won { + shouldRefund = false + shouldSettle = false + } + } else if !snap.Equal(task.Snapshot()) { + _, _ = task.UpdateWithStatus(snap.Status) + } + + if shouldSettle && actualQuota > 0 { + RecalculateTaskQuota(ctx, task, actualQuota, "test settle") + } + if shouldRefund { + RefundTaskQuota(ctx, task, task.FailReason) + } +} + +func TestCASGuardedRefund_Win(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 20, 20, 20 + const initQuota, preConsumed = 10000, 4000 + const tokenRemain = 6000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + require.NoError(t, model.DB.Create(task).Error) + + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0) + + // CAS wins: task in DB should now be FAILURE + var reloaded model.Task + require.NoError(t, model.DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status) + + // Refund should have happened + assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) + + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) +} + +func TestCASGuardedRefund_Lose(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 21, 21, 21 + const initQuota, preConsumed = 10000, 4000 + const tokenRemain = 6000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain) + seedChannel(t, channelID) + + // Create task with IN_PROGRESS in DB + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + require.NoError(t, model.DB.Create(task).Error) + + // Simulate another process already transitioning to FAILURE + model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure) + + // Our process still has the old in-memory state (IN_PROGRESS) and tries to transition + // task.Status is still IN_PROGRESS in the snapshot + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0) + + // CAS lost: user quota should NOT change (no double refund) + assert.Equal(t, initQuota, getUserQuota(t, userID)) + assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) + + // No billing log should be created + assert.Equal(t, int64(0), countLogs(t)) +} + +func TestCASGuardedSettle_Win(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, tokenID, channelID = 22, 22, 22 + const initQuota, preConsumed = 10000, 5000 + const actualQuota = 3000 // over-charged, should get partial refund + const tokenRemain = 8000 + + seedUser(t, userID, initQuota) + seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + require.NoError(t, model.DB.Create(task).Error) + + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota) + + // CAS wins: task should be SUCCESS + var reloaded model.Task + require.NoError(t, model.DB.First(&reloaded, task.ID).Error) + assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status) + + // Settlement should refund the over-charge (5000 - 3000 = 2000 back to user) + assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID)) + assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) + + // task.Quota should be updated to actualQuota + assert.Equal(t, actualQuota, task.Quota) +} + +func TestNonTerminalUpdate_NoBilling(t *testing.T) { + truncate(t) + ctx := context.Background() + + const userID, channelID = 23, 23 + const initQuota, preConsumed = 10000, 3000 + + seedUser(t, userID, initQuota) + seedChannel(t, channelID) + + task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) + task.Status = model.TaskStatus(model.TaskStatusInProgress) + task.Progress = "20%" + require.NoError(t, model.DB.Create(task).Error) + + // Simulate a non-terminal poll update (still IN_PROGRESS, progress changed) + simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0) + + // User quota should NOT change + assert.Equal(t, initQuota, getUserQuota(t, userID)) + + // No billing log + assert.Equal(t, int64(0), countLogs(t)) + + // Task progress should be updated in DB + var reloaded model.Task + require.NoError(t, model.DB.First(&reloaded, task.ID).Error) + assert.Equal(t, "50%", reloaded.Progress) +}