From b244a06ca1d8dd56289ff4556416aa8921ccb185 Mon Sep 17 00:00:00 2001 From: feitianbubu Date: Thu, 2 Oct 2025 02:46:47 +0800 Subject: [PATCH] feat: add doubao video use quota by total token --- controller/task_video.go | 84 ++++++++++++++++++++++++++++ relay/channel/task/doubao/adaptor.go | 3 + relay/common/relay_info.go | 14 +++-- 3 files changed, 95 insertions(+), 6 deletions(-) diff --git a/controller/task_video.go b/controller/task_video.go index 73d5c39b1..8e8a5852d 100644 --- a/controller/task_video.go +++ b/controller/task_video.go @@ -13,6 +13,7 @@ import ( "one-api/relay" "one-api/relay/channel" relaycommon "one-api/relay/common" + "one-api/setting/ratio_setting" "time" ) @@ -120,6 +121,89 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha if !(len(taskResult.Url) > 5 && taskResult.Url[:5] == "data:") { task.FailReason = taskResult.Url } + + // 如果返回了 total_tokens 并且配置了模型倍率(非固定价格),则重新计费 + if taskResult.TotalTokens > 0 { + // 获取模型名称 + var taskData map[string]interface{} + if err := json.Unmarshal(task.Data, &taskData); err == nil { + if modelName, ok := taskData["model"].(string); ok && modelName != "" { + // 获取模型价格和倍率 + modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName) + + // 只有配置了倍率(非固定价格)时才按 token 重新计费 + if hasRatioSetting && modelRatio > 0 { + // 获取用户和组的倍率信息 + user, err := model.GetUserById(task.UserId, false) + if err == nil { + groupRatio := ratio_setting.GetGroupRatio(user.Group) + userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(user.Group, user.Group) + + var finalGroupRatio float64 + if hasUserGroupRatio { + finalGroupRatio = userGroupRatio + } else { + finalGroupRatio = groupRatio + } + + // 计算实际应扣费额度: totalTokens * modelRatio * groupRatio + actualQuota := int(float64(taskResult.TotalTokens) * modelRatio * finalGroupRatio) + + // 计算差额 + preConsumedQuota := task.Quota + quotaDelta := actualQuota - preConsumedQuota + + if quotaDelta > 0 { + // 需要补扣费 + logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后补扣费:%s(实际消耗:%s,预扣费:%s,tokens:%d)", + task.TaskID, + logger.LogQuota(quotaDelta), + logger.LogQuota(actualQuota), + logger.LogQuota(preConsumedQuota), + taskResult.TotalTokens, + )) + if err := model.DecreaseUserQuota(task.UserId, quotaDelta); err != nil { + logger.LogError(ctx, fmt.Sprintf("补扣费失败: %s", err.Error())) + } else { + model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) + model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) + task.Quota = actualQuota // 更新任务记录的实际扣费额度 + + // 记录消费日志 + logContent := fmt.Sprintf("视频任务成功补扣费,模型倍率 %.2f,分组倍率 %.2f,tokens %d", + modelRatio, finalGroupRatio, taskResult.TotalTokens) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } + } else if quotaDelta < 0 { + // 需要退还多扣的费用 + refundQuota := -quotaDelta + logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费后返还:%s(实际消耗:%s,预扣费:%s,tokens:%d)", + task.TaskID, + logger.LogQuota(refundQuota), + logger.LogQuota(actualQuota), + logger.LogQuota(preConsumedQuota), + taskResult.TotalTokens, + )) + if err := model.IncreaseUserQuota(task.UserId, refundQuota, false); err != nil { + logger.LogError(ctx, fmt.Sprintf("退还预扣费失败: %s", err.Error())) + } else { + task.Quota = actualQuota // 更新任务记录的实际扣费额度 + + // 记录退款日志 + logContent := fmt.Sprintf("视频任务成功退还多扣费用,模型倍率 %.2f,分组倍率 %.2f,tokens %d,退还 %s", + modelRatio, finalGroupRatio, taskResult.TotalTokens, logger.LogQuota(refundQuota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } + } else { + // quotaDelta == 0, 预扣费刚好准确 + logger.LogInfo(ctx, fmt.Sprintf("视频任务 %s 预扣费准确(%s,tokens:%d)", + task.TaskID, logger.LogQuota(actualQuota), taskResult.TotalTokens)) + } + } + } + } + } + } case model.TaskStatusFailure: task.Status = model.TaskStatusFailure task.Progress = "100%" diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go index 9b40a249a..8cc1fa4f5 100644 --- a/relay/channel/task/doubao/adaptor.go +++ b/relay/channel/task/doubao/adaptor.go @@ -231,6 +231,9 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e taskResult.Status = model.TaskStatusSuccess taskResult.Progress = "100%" taskResult.Url = resTask.Content.VideoURL + // 解析 usage 信息用于按倍率计费 + taskResult.CompletionTokens = resTask.Usage.CompletionTokens + taskResult.TotalTokens = resTask.Usage.TotalTokens case "failed": taskResult.Status = model.TaskStatusFailure taskResult.Progress = "100%" diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index f4ffaee23..b2905c57b 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -500,10 +500,12 @@ func (t TaskSubmitReq) HasImage() bool { } type TaskInfo struct { - Code int `json:"code"` - TaskID string `json:"task_id"` - Status string `json:"status"` - Reason string `json:"reason,omitempty"` - Url string `json:"url,omitempty"` - Progress string `json:"progress,omitempty"` + Code int `json:"code"` + TaskID string `json:"task_id"` + Status string `json:"status"` + Reason string `json:"reason,omitempty"` + Url string `json:"url,omitempty"` + Progress string `json:"progress,omitempty"` + CompletionTokens int `json:"completion_tokens,omitempty"` // 用于按倍率计费 + TotalTokens int `json:"total_tokens,omitempty"` // 用于按倍率计费 }