mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-31 07:26:05 +00:00
Compare commits
9 Commits
v0.9.9
...
fix-gemini
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e2d3b46a3a | ||
|
|
dd775167ab | ||
|
|
43f2a8ac06 | ||
|
|
bcf93a2c05 | ||
|
|
09ff878d88 | ||
|
|
d4749ba388 | ||
|
|
1f2bdb1402 | ||
|
|
64a97092c9 | ||
|
|
69b87b5d8e |
2
.github/workflows/electron-build.yml
vendored
2
.github/workflows/electron-build.yml
vendored
@@ -4,6 +4,8 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- '*' # Triggers on version tags like v1.0.0
|
||||
- '!*-*' # Ignore pre-release tags like v1.0.0-beta
|
||||
- '!*-alpha*' # Ignore alpha tags like v1.0.0-alpha
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
jobs:
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
@@ -118,4 +119,17 @@ func initConstantEnv() {
|
||||
constant.GenerateDefaultToken = GetEnvOrDefaultBool("GENERATE_DEFAULT_TOKEN", false)
|
||||
// 是否启用错误日志
|
||||
constant.ErrorLogEnabled = GetEnvOrDefaultBool("ERROR_LOG_ENABLED", false)
|
||||
|
||||
soraPatchStr := GetEnvOrDefaultString("TASK_PRICE_PATCH", "")
|
||||
if soraPatchStr != "" {
|
||||
var taskPricePatches []string
|
||||
soraPatches := strings.Split(soraPatchStr, ",")
|
||||
for _, patch := range soraPatches {
|
||||
trimmedPatch := strings.TrimSpace(patch)
|
||||
if trimmedPatch != "" {
|
||||
taskPricePatches = append(taskPricePatches, trimmedPatch)
|
||||
}
|
||||
}
|
||||
constant.TaskPricePatches = taskPricePatches
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,3 +13,6 @@ var NotifyLimitCount int
|
||||
var NotificationLimitDurationMinute int
|
||||
var GenerateDefaultToken bool
|
||||
var ErrorLogEnabled bool
|
||||
|
||||
// temporary variable for sora patch, will be removed in future
|
||||
var TaskPricePatches []string
|
||||
|
||||
@@ -115,6 +115,12 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
//return fmt.Errorf("task %s status is empty", taskId)
|
||||
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
|
||||
}
|
||||
|
||||
// 记录原本的状态,防止重复退款
|
||||
shouldRefund := false
|
||||
quota := task.Quota
|
||||
preStatus := task.Status
|
||||
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
case model.TaskStatusSubmitted:
|
||||
@@ -225,7 +231,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
}
|
||||
}
|
||||
case model.TaskStatusFailure:
|
||||
preStatus := task.Status
|
||||
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
@@ -233,16 +239,10 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
quota := task.Quota
|
||||
taskResult.Progress = "100%"
|
||||
if quota != 0 {
|
||||
if preStatus != model.TaskStatusFailure {
|
||||
// 任务失败且之前状态不是失败才退还额度,防止重复退还
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
shouldRefund = true
|
||||
} else {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
|
||||
}
|
||||
@@ -255,6 +255,16 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
}
|
||||
if err := task.Update(); err != nil {
|
||||
common.SysLog("UpdateVideoTask task error: " + err.Error())
|
||||
shouldRefund = false
|
||||
}
|
||||
|
||||
if shouldRefund {
|
||||
// 任务失败且之前状态不是失败才退还额度,防止重复退还
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Requests []GeminiChatRequest `json:"requests,omitempty"` // For batch requests
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||
|
||||
@@ -153,5 +153,5 @@ func LogJson(ctx context.Context, msg string, obj any) {
|
||||
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
|
||||
LogDebug(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
|
||||
}
|
||||
|
||||
15
model/log.go
15
model/log.go
@@ -39,14 +39,15 @@ type Log struct {
|
||||
Other string `json:"other"`
|
||||
}
|
||||
|
||||
// don't use iota, avoid change log type value
|
||||
const (
|
||||
LogTypeUnknown = iota
|
||||
LogTypeTopup
|
||||
LogTypeConsume
|
||||
LogTypeManage
|
||||
LogTypeSystem
|
||||
LogTypeError
|
||||
LogTypeRefund
|
||||
LogTypeUnknown = 0
|
||||
LogTypeTopup = 1
|
||||
LogTypeConsume = 2
|
||||
LogTypeManage = 3
|
||||
LogTypeSystem = 4
|
||||
LogTypeError = 5
|
||||
LogTypeRefund = 6
|
||||
)
|
||||
|
||||
func formatUserLogs(logs []*Log) {
|
||||
|
||||
@@ -88,7 +88,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.ChannelOtherSettings.AwsKeyType == dto.AwsKeyTypeApiKey {
|
||||
awsModelId := awsModelID(info.UpstreamModelName)
|
||||
awsModelId := getAwsModelID(info.UpstreamModelName)
|
||||
a.ClientMode = ClientModeApiKey
|
||||
awsSecret := strings.Split(info.ApiKey, "|")
|
||||
if len(awsSecret) != 2 {
|
||||
|
||||
@@ -57,9 +57,11 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
||||
}
|
||||
a.AwsClient = awsCli
|
||||
|
||||
awsModelId := awsModelID(info.UpstreamModelName)
|
||||
println(info.UpstreamModelName)
|
||||
// 获取对应的AWS模型ID
|
||||
awsModelId := getAwsModelID(info.UpstreamModelName)
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
awsRegionPrefix := getAwsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
@@ -119,7 +121,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
||||
}
|
||||
}
|
||||
|
||||
func awsRegionPrefix(awsRegionId string) string {
|
||||
func getAwsRegionPrefix(awsRegionId string) string {
|
||||
parts := strings.Split(awsRegionId, "-")
|
||||
regionPrefix := ""
|
||||
if len(parts) > 0 {
|
||||
@@ -141,11 +143,10 @@ func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
|
||||
return modelPrefix + "." + awsModelId
|
||||
}
|
||||
|
||||
func awsModelID(requestModel string) string {
|
||||
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
|
||||
return awsModelID
|
||||
func getAwsModelID(requestModel string) string {
|
||||
if awsModelIDName, ok := awsModelIDMap[requestModel]; ok {
|
||||
return awsModelIDName
|
||||
}
|
||||
|
||||
return requestModel
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
@@ -48,6 +49,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData)))
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
|
||||
@@ -240,6 +240,8 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
req.SetModelName("models/" + info.UpstreamModelName)
|
||||
|
||||
adaptor := GetAdaptor(info.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
@@ -264,6 +266,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
logger.LogDebug(c, "Gemini embedding request body: "+string(jsonData))
|
||||
requestBody = bytes.NewReader(jsonData)
|
||||
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
|
||||
@@ -22,8 +22,10 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt
|
||||
case types.RelayFormatOpenAI:
|
||||
request, err = GetAndValidateTextRequest(c, relayMode)
|
||||
case types.RelayFormatGemini:
|
||||
if strings.Contains(c.Request.URL.Path, ":embedContent") || strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
|
||||
if strings.Contains(c.Request.URL.Path, ":embedContent") {
|
||||
request, err = GetAndValidateGeminiEmbeddingRequest(c)
|
||||
} else if strings.Contains(c.Request.URL.Path, ":batchEmbedContents") {
|
||||
request, err = GetAndValidateGeminiBatchEmbeddingRequest(c)
|
||||
} else {
|
||||
request, err = GetAndValidateGeminiRequest(c)
|
||||
}
|
||||
@@ -300,7 +302,7 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(request.Contents) == 0 {
|
||||
if len(request.Contents) == 0 && len(request.Requests) == 0 {
|
||||
return nil, errors.New("contents is required")
|
||||
}
|
||||
|
||||
@@ -319,3 +321,12 @@ func GetAndValidateGeminiEmbeddingRequest(c *gin.Context) (*dto.GeminiEmbeddingR
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func GetAndValidateGeminiBatchEmbeddingRequest(c *gin.Context) (*dto.GeminiBatchEmbeddingRequest, error) {
|
||||
request := &dto.GeminiBatchEmbeddingRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -139,7 +139,7 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
|
||||
return &taskVidu.TaskAdaptor{}
|
||||
case constant.ChannelTypeDoubaoVideo:
|
||||
return &taskdoubao.TaskAdaptor{}
|
||||
case constant.ChannelTypeSora:
|
||||
case constant.ChannelTypeSora, constant.ChannelTypeOpenAI:
|
||||
return &tasksora.TaskAdaptor{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,10 +72,13 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
} else {
|
||||
ratio = modelPrice * groupRatio
|
||||
}
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
for _, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
ratio *= ra
|
||||
// FIXME: 临时修补,支持任务仅按次计费
|
||||
if !common.StringsContains(constant.TaskPricePatches, modelName) {
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
for _, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
ratio *= ra
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -153,15 +156,20 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
|
||||
// gRatio = userGroupRatio
|
||||
//}
|
||||
logContent := fmt.Sprintf("操作 %s", info.Action)
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
var contents []string
|
||||
for key, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
|
||||
// FIXME: 临时修补,支持任务仅按次计费
|
||||
if common.StringsContains(constant.TaskPricePatches, modelName) {
|
||||
logContent = fmt.Sprintf("%s,按次计费", logContent)
|
||||
} else {
|
||||
if len(info.PriceData.OtherRatios) > 0 {
|
||||
var contents []string
|
||||
for key, ra := range info.PriceData.OtherRatios {
|
||||
if 1.0 != ra {
|
||||
contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra))
|
||||
}
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
|
||||
}
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", "))
|
||||
}
|
||||
}
|
||||
other := make(map[string]interface{})
|
||||
|
||||
Reference in New Issue
Block a user