Compare commits

..

4 Commits

Author SHA1 Message Date
creamlike1024
e2d3b46a3a fix: gemini batch embedding token not counted 2025-10-17 15:51:04 +08:00
CaIon
dd775167ab feat: support OpenAI channel type in sora relay adaptor 2025-10-17 13:53:15 +08:00
CaIon
43f2a8ac06 feat: add temporary TASK_PRICE_PATCH configuration to environment variables 2025-10-16 21:59:21 +08:00
CaIon
bcf93a2c05 fix: prevent refund on video task update error 2025-10-16 12:46:07 +08:00
6 changed files with 51 additions and 14 deletions

View File

@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/QuantumNous/new-api/constant"
@@ -118,4 +119,17 @@ func initConstantEnv() {
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
// 是否启用错误日志
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
if soraPatchStr != "" {
var taskPricePatches []string
soraPatches := strings.Split(soraPatchStr, ",")
for _, patch := range soraPatches {
trimmedPatch := strings.TrimSpace(patch)
if trimmedPatch != "" {
taskPricePatches = append(taskPricePatches, trimmedPatch)
}
}
constant.TaskPricePatches = taskPricePatches
}
}

View File

@@ -13,3 +13,6 @@ var NotifyLimitCount int
var NotificationLimitDurationMinute int
var GenerateDefaultToken bool
var ErrorLogEnabled bool
// temporary variable for sora patch, will be removed in future
var TaskPricePatches []string

View File

@@ -255,6 +255,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
}
if err := task.Update(); err != nil {
common.SysLog("UpdateVideoTask task error: " + err.Error())
shouldRefund = false
}
if shouldRefund {

View File

@@ -22,8 +22,10 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt
case types.RelayFormatOpenAI:
request, err = GetAndValidateTextRequest(c, relayMode)
case types.RelayFormatGemini:
if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
if strings.Contains(c.Request.URL.Path, ":embedContent") {
request, err = GetAndValidateGeminiEmbeddingRequest(c)
} else if strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
request, err = GetAndValidateGeminiBatchEmbeddingRequest(c)
} else {
request, err = GetAndValidateGeminiRequest(c)
}
@@ -319,3 +321,12 @@ func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingR
}
return request, nil
}
func GetAndValidateGeminiBatchEmbeddingRequest(c *gin.Context) (*dto.GeminiBatchEmbeddingRequest, error) {
request := &dto.GeminiBatchEmbeddingRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
return nil, err
}
return request, nil
}

View File

@@ -139,7 +139,7 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
return &taskVidu.TaskAdaptor{}
case constant.ChannelTypeDoubaoVideo:
return &taskdoubao.TaskAdaptor{}
case constant.ChannelTypeSora:
case constant.ChannelTypeSora, constant.ChannelTypeOpenAI:
return &tasksora.TaskAdaptor{}
}
}

View File

@@ -72,10 +72,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
} else {
ratio = modelPrice * groupRatio
}
if len(info.PriceData.OtherRatios) > 0 {
for _, ra := range info.PriceData.OtherRatios {
if 1.0 != ra {
ratio *= ra
// FIXME: 临时修补,支持任务仅按次计费
if !common.StringsContains(constant.TaskPricePatches, modelName) {
if len(info.PriceData.OtherRatios) > 0 {
for _, ra := range info.PriceData.OtherRatios {
if 1.0 != ra {
ratio *= ra
}
}
}
}
@@ -153,15 +156,20 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
// gRatio = userGroupRatio
//}
logContent := fmt.Sprintf("操作 %s", info.Action)
if len(info.PriceData.OtherRatios) > 0 {
var contents []string
for key, ra := range info.PriceData.OtherRatios {
if 1.0 != ra {
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
// FIXME: 临时修补,支持任务仅按次计费
if common.StringsContains(constant.TaskPricePatches, modelName) {
logContent = fmt.Sprintf("%s按次计费", logContent)
} else {
if len(info.PriceData.OtherRatios) > 0 {
var contents []string
for key, ra := range info.PriceData.OtherRatios {
if 1.0 != ra {
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
}
}
if len(contents) > 0 {
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
}
}
if len(contents) > 0 {
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
}
}
other := make(map[string]interface{})