diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index f76011495..5f4b2714c 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/gin-gonic/gin" + "github.com/samber/lo" ) type HasPrompt interface { @@ -156,6 +157,34 @@ func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError { action = constant.TaskActionGenerate } info.Action = action + model := form.Value["model"][0] + if strings.HasPrefix(model, "sora-2") { + seconds := 4 + size := "720x1280" + if ss, ok := form.Value["seconds"]; ok { + sInt := common.String2Int(ss[0]) + if sInt > seconds { + seconds = common.String2Int(ss[0]) + } + } + if s, ok := form.Value["size"]; ok { + size = s[0] + } + + if model == "sora-2" && !lo.Contains([]string{"720x1280", "1280x720"}, size) { + return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true) + } + if model == "sora-2-pro" && !lo.Contains([]string{"720x1280", "1280x720", "1792x1024", "1024x1792"}, size) { + return createTaskError(fmt.Errorf("sora-2 size is invalid"), "invalid_size", http.StatusBadRequest, true) + } + info.PriceData.OtherRatios = map[string]float64{ + "seconds": float64(seconds), + "size": 1, + } + if lo.Contains([]string{"1792x1024", "1024x1792"}, size) { + info.PriceData.OtherRatios["size"] = 1.666667 + } + } return nil } diff --git a/relay/helper/price.go b/relay/helper/price.go index c23c068b3..08ae78c4c 100644 --- a/relay/helper/price.go +++ b/relay/helper/price.go @@ -114,7 +114,7 @@ func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types. modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true) // 如果没有配置价格,则使用默认价格 if !success { - defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName] + defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[info.OriginModelName] if !ok { modelPrice = 0.1 } else { diff --git a/relay/relay_task.go b/relay/relay_task.go index 0c4e9604c..057c64b07 100644 --- a/relay/relay_task.go +++ b/relay/relay_task.go @@ -54,7 +54,7 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto. } modelPrice, success := ratio_setting.GetModelPrice(modelName, true) if !success { - defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName] + defaultPrice, ok := ratio_setting.GetDefaultModelPriceMap()[modelName] if !ok { modelPrice = 0.1 } else { @@ -71,6 +71,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto. } else { ratio = modelPrice * groupRatio } + if len(info.PriceData.OtherRatios) > 0 { + for _, ra := range info.PriceData.OtherRatios { + if 1.0 != ra { + ratio *= ra + } + } + } userQuota, err := model.GetUserQuota(info.UserId, false) if err != nil { taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) @@ -144,6 +151,17 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto. gRatio = userGroupRatio } logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, info.Action) + if len(info.PriceData.OtherRatios) > 0 { + var contents []string + for key, ra := range info.PriceData.OtherRatios { + if 1.0 != ra { + contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra)) + } + } + if len(contents) > 0 { + logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", ")) + } + } other := make(map[string]interface{}) other["model_price"] = modelPrice other["group_ratio"] = groupRatio diff --git a/setting/ratio_setting/model_ratio.go b/setting/ratio_setting/model_ratio.go index 5e55576fd..f8ffbaa8c 100644 --- a/setting/ratio_setting/model_ratio.go +++ b/setting/ratio_setting/model_ratio.go @@ -290,6 +290,8 @@ var defaultModelPrice = map[string]float64{ "mj_upscale": 0.05, "swap_face": 0.05, "mj_upload": 0.05, + "sora-2": 0.3, + "sora-2-pro": 0.5, } var defaultAudioRatio = map[string]float64{ @@ -452,6 +454,10 @@ func GetDefaultModelRatioMap() map[string]float64 { return defaultModelRatio } +func GetDefaultModelPriceMap() map[string]float64 { + return defaultModelPrice +} + func GetDefaultImageRatioMap() map[string]float64 { return defaultImageRatio } diff --git a/types/price_data.go b/types/price_data.go index ec7fcdfe9..dd5dad1aa 100644 --- a/types/price_data.go +++ b/types/price_data.go @@ -17,6 +17,7 @@ type PriceData struct { ImageRatio float64 AudioRatio float64 AudioCompletionRatio float64 + OtherRatios map[string]float64 UsePrice bool ShouldPreConsumedQuota int GroupRatioInfo GroupRatioInfo