mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 17:25:02 +00:00
Compare commits
8 Commits
v0.9.9-aws
...
v0.9.9-pat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
09ff878d88 | ||
|
|
d4749ba388 | ||
|
|
1f2bdb1402 | ||
|
|
64a97092c9 | ||
|
|
69b87b5d8e | ||
|
|
bd4160793e | ||
|
|
82e21972ec | ||
|
|
dce00141ce |
2
.github/workflows/electron-build.yml
vendored
2
.github/workflows/electron-build.yml
vendored
@@ -4,6 +4,8 @@ on:
|
||||
push:
|
||||
tags:
|
||||
- '*' # Triggers on version tags like v1.0.0
|
||||
- '!*-*' # Ignore pre-release tags like v1.0.0-beta
|
||||
- '!*-alpha*' # Ignore alpha tags like v1.0.0-alpha
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
jobs:
|
||||
|
||||
@@ -115,6 +115,12 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
//return fmt.Errorf("task %s status is empty", taskId)
|
||||
taskResult = relaycommon.FailTaskInfo("upstream returned empty status")
|
||||
}
|
||||
|
||||
// 记录原本的状态,防止重复退款
|
||||
shouldRefund := false
|
||||
quota := task.Quota
|
||||
preStatus := task.Status
|
||||
|
||||
task.Status = model.TaskStatus(taskResult.Status)
|
||||
switch taskResult.Status {
|
||||
case model.TaskStatusSubmitted:
|
||||
@@ -225,7 +231,7 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
}
|
||||
}
|
||||
case model.TaskStatusFailure:
|
||||
preStatus := task.Status
|
||||
logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task)
|
||||
task.Status = model.TaskStatusFailure
|
||||
task.Progress = "100%"
|
||||
if task.FinishTime == 0 {
|
||||
@@ -233,16 +239,10 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
}
|
||||
task.FailReason = taskResult.Reason
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason))
|
||||
quota := task.Quota
|
||||
taskResult.Progress = "100%"
|
||||
if quota != 0 {
|
||||
if preStatus != model.TaskStatusFailure {
|
||||
// 任务失败且之前状态不是失败才退还额度,防止重复退还
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
shouldRefund = true
|
||||
} else {
|
||||
logger.LogWarn(ctx, fmt.Sprintf("Task %s already in failure status, skip refund", task.TaskID))
|
||||
}
|
||||
@@ -257,6 +257,15 @@ func updateVideoSingleTask(ctx context.Context, adaptor channel.TaskAdaptor, cha
|
||||
common.SysLog("UpdateVideoTask task error: " + err.Error())
|
||||
}
|
||||
|
||||
if shouldRefund {
|
||||
// 任务失败且之前状态不是失败才退还额度,防止重复退还
|
||||
if err := model.IncreaseUserQuota(task.UserId, quota, false); err != nil {
|
||||
logger.LogWarn(ctx, "Failed to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("Video async task failed %s, refund %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -148,6 +148,10 @@ func (c *ClaudeMessage) SetStringContent(content string) {
|
||||
c.Content = content
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) SetContent(content any) {
|
||||
c.Content = content
|
||||
}
|
||||
|
||||
func (c *ClaudeMessage) ParseContent() ([]ClaudeMediaMessage, error) {
|
||||
return common.Any2Type[[]ClaudeMediaMessage](c.Content)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Requests []GeminiChatRequest `json:"requests,omitempty"` // For batch requests
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
|
||||
|
||||
@@ -153,5 +153,5 @@ func LogJson(ctx context.Context, msg string, obj any) {
|
||||
LogError(ctx, fmt.Sprintf("json marshal failed: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
LogInfo(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
|
||||
LogDebug(ctx, fmt.Sprintf("%s | %s", msg, string(jsonStr)))
|
||||
}
|
||||
|
||||
15
model/log.go
15
model/log.go
@@ -39,14 +39,15 @@ type Log struct {
|
||||
Other string `json:"other"`
|
||||
}
|
||||
|
||||
// don't use iota, avoid change log type value
|
||||
const (
|
||||
LogTypeUnknown = iota
|
||||
LogTypeTopup
|
||||
LogTypeConsume
|
||||
LogTypeManage
|
||||
LogTypeSystem
|
||||
LogTypeRefund
|
||||
LogTypeError
|
||||
LogTypeUnknown = 0
|
||||
LogTypeTopup = 1
|
||||
LogTypeConsume = 2
|
||||
LogTypeManage = 3
|
||||
LogTypeSystem = 4
|
||||
LogTypeError = 5
|
||||
LogTypeRefund = 6
|
||||
)
|
||||
|
||||
func formatUserLogs(logs []*Log) {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
"github.com/QuantumNous/new-api/relay/channel/claude"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||
"github.com/pkg/errors"
|
||||
@@ -38,6 +39,37 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
|
||||
for i, message := range request.Messages {
|
||||
updated := false
|
||||
if !message.IsStringContent() {
|
||||
content, err := message.ParseContent()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse message content")
|
||||
}
|
||||
for i2, mediaMessage := range content {
|
||||
if mediaMessage.Source != nil {
|
||||
if mediaMessage.Source.Type == "url" {
|
||||
fileData, err := service.GetFileBase64FromUrl(c, mediaMessage.Source.Url, "formatting image for Claude")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
|
||||
}
|
||||
mediaMessage.Source.MediaType = fileData.MimeType
|
||||
mediaMessage.Source.Data = fileData.Base64Data
|
||||
mediaMessage.Source.Url = ""
|
||||
mediaMessage.Source.Type = "base64"
|
||||
content[i2] = mediaMessage
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if updated {
|
||||
message.SetContent(content)
|
||||
}
|
||||
}
|
||||
if updated {
|
||||
request.Messages[i] = message
|
||||
}
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
@@ -56,7 +88,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if info.ChannelOtherSettings.AwsKeyType == dto.AwsKeyTypeApiKey {
|
||||
awsModelId := awsModelID(info.UpstreamModelName)
|
||||
awsModelId := getAwsModelID(info.UpstreamModelName)
|
||||
a.ClientMode = ClientModeApiKey
|
||||
awsSecret := strings.Split(info.ApiKey, "|")
|
||||
if len(awsSecret) != 2 {
|
||||
@@ -89,12 +121,11 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
// 原有的Claude模型处理逻辑
|
||||
var claudeReq *dto.ClaudeRequest
|
||||
var err error
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, "failed to convert openai request to claude request")
|
||||
}
|
||||
info.UpstreamModelName = claudeReq.Model
|
||||
return claudeReq, err
|
||||
}
|
||||
|
||||
|
||||
@@ -57,9 +57,11 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
||||
}
|
||||
a.AwsClient = awsCli
|
||||
|
||||
awsModelId := awsModelID(info.UpstreamModelName)
|
||||
println(info.UpstreamModelName)
|
||||
// 获取对应的AWS模型ID
|
||||
awsModelId := getAwsModelID(info.UpstreamModelName)
|
||||
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
awsRegionPrefix := getAwsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
@@ -119,7 +121,7 @@ func doAwsClientRequest(c *gin.Context, info *relaycommon.RelayInfo, a *Adaptor,
|
||||
}
|
||||
}
|
||||
|
||||
func awsRegionPrefix(awsRegionId string) string {
|
||||
func getAwsRegionPrefix(awsRegionId string) string {
|
||||
parts := strings.Split(awsRegionId, "-")
|
||||
regionPrefix := ""
|
||||
if len(parts) > 0 {
|
||||
@@ -141,11 +143,10 @@ func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
|
||||
return modelPrefix + "." + awsModelId
|
||||
}
|
||||
|
||||
func awsModelID(requestModel string) string {
|
||||
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
|
||||
return awsModelID
|
||||
func getAwsModelID(requestModel string) string {
|
||||
if awsModelIDName, ok := awsModelIDMap[requestModel]; ok {
|
||||
return awsModelIDName
|
||||
}
|
||||
|
||||
return requestModel
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/relay/helper"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
@@ -48,6 +49,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData)))
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
|
||||
@@ -240,6 +240,8 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
|
||||
req.SetModelName("models/" + info.UpstreamModelName)
|
||||
|
||||
adaptor := GetAdaptor(info.ApiType)
|
||||
if adaptor == nil {
|
||||
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
|
||||
@@ -264,6 +266,7 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
}
|
||||
logger.LogDebug(c, "Gemini embedding request body: "+string(jsonData))
|
||||
requestBody = bytes.NewReader(jsonData)
|
||||
|
||||
resp, err := adaptor.DoRequest(c, info, requestBody)
|
||||
|
||||
@@ -300,7 +300,7 @@ func GetAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiChatRequest, error)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(request.Contents) == 0 {
|
||||
if len(request.Contents) == 0 && len(request.Requests) == 0 {
|
||||
return nil, errors.New("contents is required")
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user