mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-04-02 11:56:10 +00:00
Compare commits
6 Commits
v0.9.9-pat
...
fix-gemini
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2d3b46a3a | ||
|
|
dd775167ab | ||
|
|
43f2a8ac06 | ||
|
|
bcf93a2c05 | ||
|
|
09ff878d88 | ||
|
|
d4749ba388 |
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
@@ -118,4 +119,17 @@ func initConstantEnv() {
|
||||
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
|
||||
soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
|
||||
if soraPatchStr != "" {
|
||||
var taskPricePatches []string
|
||||
soraPatches := strings.Split(soraPatchStr, ",")
|
||||
for _, patch := range soraPatches {
|
||||
trimmedPatch := strings.TrimSpace(patch)
|
||||
if trimmedPatch != "" {
|
||||
taskPricePatches = append(taskPricePatches, trimmedPatch)
|
||||
}
|
||||
}
|
||||
constant.TaskPricePatches = taskPricePatches
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,3 +13,6 @@ var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
var GenerateDefaultToken bool
|
||||
var ErrorLogEnabled bool
|
||||
|
||||
// temporary variable for sora patch, will be removed in future
|
||||
var TaskPricePatches []string
|
||||
|
||||
@@ -115,6 +115,12 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
//return fmt.Errorf("task %s status is empty", taskId)
|
||||
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
|
||||
}
|
||||
|
||||
// 记录原本的状态,防止重复退款
|
||||
shouldRefund := false
|
||||
quota := task.Quota
|
||||
preStatus := task.Status
|
||||
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
case model.TaskStatusSubmitted:
|
||||
@@ -225,7 +231,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
}
|
||||
}
|
||||
case model.TaskStatusFailure:
|
||||
preStatus := task.Status
|
||||
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
@@ -233,16 +239,10 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
quota := task.Quota
|
||||
taskResult.Progress = "100%"
|
||||
if quota != 0 {
|
||||
if preStatus != model.TaskStatusFailure {
|
||||
// 任务失败且之前状态不是失败才退还额度,防止重复退还
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
shouldRefund = true
|
||||
} else {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
|
||||
}
|
||||
@@ -255,6 +255,16 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
}
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysLog("UpdateVideoTask task error: " + err.Error())
|
||||
shouldRefund = false
|
||||
}
|
||||
|
||||
if shouldRefund {
|
||||
// 任务失败且之前状态不是失败才退还额度,防止重复退还
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -153,5 +153,5 @@ func LogJson(ctx context.Context, msg string, obj any) {
|
||||
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
|
||||
LogDebug(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
|
||||
}
|
||||
|
||||
@@ -88,7 +88,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.ChannelOtherSettings.AwsKeyType == dto.AwsKeyTypeApiKey {
|
||||
awsModelId := awsModelID(info.UpstreamModelName)
|
||||
awsModelId := getAwsModelID(info.UpstreamModelName)
|
||||
a.ClientMode = ClientModeApiKey
|
||||
awsSecret := strings.Split(info.ApiKey, "|")
|
||||
if len(awsSecret) != 2 {
|
||||
|
||||
@@ -57,9 +57,11 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
||||
}
|
||||
a.AwsClient = awsCli
|
||||
|
||||
awsModelId := awsModelID(info.UpstreamModelName)
|
||||
println(info.UpstreamModelName)
|
||||
// 获取对应的AWS模型ID
|
||||
awsModelId := getAwsModelID(info.UpstreamModelName)
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
awsRegionPrefix := getAwsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
@@ -119,7 +121,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
||||
}
|
||||
}
|
||||
|
||||
func awsRegionPrefix(awsRegionId string) string {
|
||||
func getAwsRegionPrefix(awsRegionId string) string {
|
||||
parts := strings.Split(awsRegionId, "-")
|
||||
regionPrefix := ""
|
||||
if len(parts) > 0 {
|
||||
@@ -141,11 +143,10 @@ func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
|
||||
return modelPrefix + "." + awsModelId
|
||||
}
|
||||
|
||||
func awsModelID(requestModel string) string {
|
||||
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
|
||||
return awsModelID
|
||||
func getAwsModelID(requestModel string) string {
|
||||
if awsModelIDName, ok := awsModelIDMap[requestModel]; ok {
|
||||
return awsModelIDName
|
||||
}
|
||||
|
||||
return requestModel
|
||||
}
|
||||
|
||||
|
||||
@@ -22,8 +22,10 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt
|
||||
case types.RelayFormatOpenAI:
|
||||
request, err = GetAndValidateTextRequest(c, relayMode)
|
||||
case types.RelayFormatGemini:
|
||||
if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
|
||||
if strings.Contains(c.Request.URL.Path, ":embedContent") {
|
||||
request, err = GetAndValidateGeminiEmbeddingRequest(c)
|
||||
} else if strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
|
||||
request, err = GetAndValidateGeminiBatchEmbeddingRequest(c)
|
||||
} else {
|
||||
request, err = GetAndValidateGeminiRequest(c)
|
||||
}
|
||||
@@ -319,3 +321,12 @@ func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingR
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func GetAndValidateGeminiBatchEmbeddingRequest(c *gin.Context) (*dto.GeminiBatchEmbeddingRequest, error) {
|
||||
request := &dto.GeminiBatchEmbeddingRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -139,7 +139,7 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
|
||||
return &taskVidu.TaskAdaptor{}
|
||||
case constant.ChannelTypeDoubaoVideo:
|
||||
return &taskdoubao.TaskAdaptor{}
|
||||
case constant.ChannelTypeSora:
|
||||
case constant.ChannelTypeSora, constant.ChannelTypeOpenAI:
|
||||
return &tasksora.TaskAdaptor{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,10 +72,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
} else {
|
||||
ratio = modelPrice * groupRatio
|
||||
}
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
for _, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
ratio *= ra
|
||||
// FIXME: 临时修补,支持任务仅按次计费
|
||||
if !common.StringsContains(constant.TaskPricePatches, modelName) {
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
for _, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
ratio *= ra
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -153,15 +156,20 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
// gRatio = userGroupRatio
|
||||
//}
|
||||
logContent := fmt.Sprintf("操作 %s", info.Action)
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
var contents []string
|
||||
for key, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
|
||||
// FIXME: 临时修补,支持任务仅按次计费
|
||||
if common.StringsContains(constant.TaskPricePatches, modelName) {
|
||||
logContent = fmt.Sprintf("%s,按次计费", logContent)
|
||||
} else {
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
var contents []string
|
||||
for key, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
|
||||
}
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
|
||||
}
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
|
||||
}
|
||||
}
|
||||
other := make(map[string]interface{})
|
||||
|
||||
Reference in New Issue
Block a user