From 48d358faecd0859a53c50bea71e1fc3ee46739cd Mon Sep 17 00:00:00 2001 From: CaIon Date: Mon, 29 Dec 2025 22:58:32 +0800 Subject: [PATCH] =?UTF-8?q?feat(adaptor):=20=E6=96=B0=E9=80=82=E9=85=8D?= =?UTF-8?q?=E7=99=BE=E7=82=BC=E5=A4=9A=E7=A7=8D=E5=9B=BE=E7=89=87=E7=94=9F?= =?UTF-8?q?=E6=88=90=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - wan2.6系列生图与编辑,适配多图生成计费 - wan2.5系列生图与编辑 - z-image-turbo生图,适配prompt_extend计费 --- controller/token.go | 37 +++++++- dto/openai_image.go | 6 +- dto/openai_request.go | 5 +- main.go | 2 + relay/audio_handler.go | 2 +- relay/channel/ali/adaptor.go | 60 +++++++++--- relay/channel/ali/dto.go | 116 ++++++++++++++++++---- relay/channel/ali/image.go | 169 ++++++++++++++++----------------- relay/channel/ali/image_wan.go | 14 ++- relay/compatible_handler.go | 46 +++++---- relay/embedding_handler.go | 2 +- relay/gemini_handler.go | 4 +- relay/image_handler.go | 12 ++- relay/rerank_handler.go | 2 +- relay/responses_handler.go | 2 +- types/price_data.go | 12 ++- 16 files changed, 336 insertions(+), 155 deletions(-) diff --git a/controller/token.go b/controller/token.go index efefea0eb..c5dc5ec42 100644 --- a/controller/token.go +++ b/controller/token.go @@ -1,6 +1,7 @@ package controller import ( + "fmt" "net/http" "strconv" "strings" @@ -149,6 +150,24 @@ func AddToken(c *gin.Context) { }) return } + // 非无限额度时,检查额度值是否超出有效范围 + if !token.UnlimitedQuota { + if token.RemainQuota < 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "额度值不能为负数", + }) + return + } + maxQuotaValue := int((1000000000 * common.QuotaPerUnit)) + if token.RemainQuota > maxQuotaValue { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue), + }) + return + } + } key, err := common.GenerateKey() if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -216,6 +235,23 @@ func UpdateToken(c *gin.Context) { }) return } + if !token.UnlimitedQuota { + if token.RemainQuota < 0 { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "额度值不能为负数", + }) + return + } + maxQuotaValue := int((1000000000 * common.QuotaPerUnit)) + if token.RemainQuota > maxQuotaValue { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue), + }) + return + } + } cleanToken, err := model.GetTokenByIds(token.Id, userId) if err != nil { common.ApiError(c, err) @@ -261,7 +297,6 @@ func UpdateToken(c *gin.Context) { "message": "", "data": cleanToken, }) - return } type TokenBatch struct { diff --git a/dto/openai_image.go b/dto/openai_image.go index 130d1dde8..a19bb69d6 100644 --- a/dto/openai_image.go +++ b/dto/openai_image.go @@ -167,9 +167,9 @@ func (i *ImageRequest) SetModelName(modelName string) { } type ImageResponse struct { - Data []ImageData `json:"data"` - Created int64 `json:"created"` - Extra any `json:"extra,omitempty"` + Data []ImageData `json:"data"` + Created int64 `json:"created"` + Metadata json.RawMessage `json:"metadata,omitempty"` } type ImageData struct { Url string `json:"url"` diff --git a/dto/openai_request.go b/dto/openai_request.go index 5415e67f3..232a1ae1b 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -23,6 +23,8 @@ type FormatJsonSchema struct { Strict json.RawMessage `json:"strict,omitempty"` } +// GeneralOpenAIRequest represents a general request structure for OpenAI-compatible APIs. +// 参数增加规范:无引用的参数必须使用json.RawMessage类型,并添加omitempty标签 type GeneralOpenAIRequest struct { Model string `json:"model,omitempty"` Messages []Message `json:"messages,omitempty"` @@ -82,8 +84,9 @@ type GeneralOpenAIRequest struct { Reasoning json.RawMessage `json:"reasoning,omitempty"` // Ali Qwen Params VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"` - EnableThinking any `json:"enable_thinking,omitempty"` + EnableThinking json.RawMessage `json:"enable_thinking,omitempty"` ChatTemplateKwargs json.RawMessage `json:"chat_template_kwargs,omitempty"` + EnableSearch json.RawMessage `json:"enable_search,omitempty"` // ollama Params Think json.RawMessage `json:"think,omitempty"` // baidu v2 diff --git a/main.go b/main.go index 8484257bf..4c0fc8c6e 100644 --- a/main.go +++ b/main.go @@ -188,6 +188,7 @@ func InjectUmamiAnalytics() { analyticsInjectBuilder.WriteString(umamiSiteID) analyticsInjectBuilder.WriteString("\">") } + analyticsInjectBuilder.WriteString("\n") analyticsInject := analyticsInjectBuilder.String() indexPage = bytes.ReplaceAll(indexPage, []byte("\n"), []byte(analyticsInject)) } @@ -209,6 +210,7 @@ func InjectGoogleAnalytics() { analyticsInjectBuilder.WriteString("');") analyticsInjectBuilder.WriteString("") } + analyticsInjectBuilder.WriteString("\n") analyticsInject := analyticsInjectBuilder.String() indexPage = bytes.ReplaceAll(indexPage, []byte("\n"), []byte(analyticsInject)) } diff --git a/relay/audio_handler.go b/relay/audio_handler.go index 39eb03d39..5c34b7923 100644 --- a/relay/audio_handler.go +++ b/relay/audio_handler.go @@ -70,7 +70,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 { service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") } else { - postConsumeQuota(c, info, usage.(*dto.Usage), "") + postConsumeQuota(c, info, usage.(*dto.Usage)) } return nil diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index adce01822..480c21371 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -19,6 +19,22 @@ import ( ) type Adaptor struct { + IsSyncImageModel bool +} + +var syncModels = []string{ + "z-image", + "qwen-image", + "wan2.6", +} + +func isSyncImageModel(modelName string) bool { + for _, m := range syncModels { + if strings.Contains(modelName, m) { + return true + } + } + return false } func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { @@ -45,10 +61,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { case constant.RelayModeRerank: fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl) case constant.RelayModeImagesGenerations: - fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl) + if isSyncImageModel(info.OriginModelName) { + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl) + } else { + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl) + } case constant.RelayModeImagesEdits: - if isWanModel(info.OriginModelName) { + if isOldWanModel(info.OriginModelName) { fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image2image/image-synthesis", info.ChannelBaseUrl) + } else if isWanModel(info.OriginModelName) { + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image-generation/generation", info.ChannelBaseUrl) } else { fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl) } @@ -72,7 +94,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel req.Set("X-DashScope-Plugin", c.GetString("plugin")) } if info.RelayMode == constant.RelayModeImagesGenerations { - req.Set("X-DashScope-Async", "enable") + if isSyncImageModel(info.OriginModelName) { + + } else { + req.Set("X-DashScope-Async", "enable") + } } if info.RelayMode == constant.RelayModeImagesEdits { if isWanModel(info.OriginModelName) { @@ -108,15 +134,25 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { if info.RelayMode == constant.RelayModeImagesGenerations { - aliRequest, err := oaiImage2Ali(request) + if isSyncImageModel(info.OriginModelName) { + a.IsSyncImageModel = true + } + aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel) if err != nil { - return nil, fmt.Errorf("convert image request failed: %w", err) + return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err) } return aliRequest, nil } else if info.RelayMode == constant.RelayModeImagesEdits { - if isWanModel(info.OriginModelName) { + if isOldWanModel(info.OriginModelName) { return oaiFormEdit2WanxImageEdit(c, info, request) } + if isSyncImageModel(info.OriginModelName) { + if isWanModel(info.OriginModelName) { + a.IsSyncImageModel = false + } else { + a.IsSyncImageModel = true + } + } // ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416 // 如果用户使用表单,则需要解析表单数据 if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { @@ -126,9 +162,9 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } return aliRequest, nil } else { - aliRequest, err := oaiImage2Ali(request) + aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel) if err != nil { - return nil, fmt.Errorf("convert image request failed: %w", err) + return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err) } return aliRequest, nil } @@ -169,13 +205,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom default: switch info.RelayMode { case constant.RelayModeImagesGenerations: - err, usage = aliImageHandler(c, resp, info) + err, usage = aliImageHandler(a, c, resp, info) case constant.RelayModeImagesEdits: - if isWanModel(info.OriginModelName) { - err, usage = aliImageHandler(c, resp, info) - } else { - err, usage = aliImageEditHandler(c, resp, info) - } + err, usage = aliImageHandler(a, c, resp, info) case constant.RelayModeRerank: err, usage = RerankHandler(c, resp, info) default: diff --git a/relay/channel/ali/dto.go b/relay/channel/ali/dto.go index 26f14a6c0..75be8ff79 100644 --- a/relay/channel/ali/dto.go +++ b/relay/channel/ali/dto.go @@ -1,6 +1,13 @@ package ali -import "github.com/QuantumNous/new-api/dto" +import ( + "strings" + + "github.com/QuantumNous/new-api/dto" + "github.com/QuantumNous/new-api/logger" + "github.com/QuantumNous/new-api/service" + "github.com/gin-gonic/gin" +) type AliMessage struct { Content any `json:"content"` @@ -65,6 +72,7 @@ type AliUsage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` TotalTokens int `json:"total_tokens"` + ImageCount int `json:"image_count,omitempty"` } type TaskResult struct { @@ -75,14 +83,78 @@ type TaskResult struct { } type AliOutput struct { - TaskId string `json:"task_id,omitempty"` - TaskStatus string `json:"task_status,omitempty"` - Text string `json:"text"` - FinishReason string `json:"finish_reason"` - Message string `json:"message,omitempty"` - Code string `json:"code,omitempty"` - Results []TaskResult `json:"results,omitempty"` - Choices []map[string]any `json:"choices,omitempty"` + TaskId string `json:"task_id,omitempty"` + TaskStatus string `json:"task_status,omitempty"` + Text string `json:"text"` + FinishReason string `json:"finish_reason"` + Message string `json:"message,omitempty"` + Code string `json:"code,omitempty"` + Results []TaskResult `json:"results,omitempty"` + Choices []struct { + FinishReason string `json:"finish_reason,omitempty"` + Message struct { + Role string `json:"role,omitempty"` + Content []AliMediaContent `json:"content,omitempty"` + ReasoningContent string `json:"reasoning_content,omitempty"` + } `json:"message,omitempty"` + } `json:"choices,omitempty"` +} + +func (o *AliOutput) ChoicesToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData { + var imageData []dto.ImageData + if len(o.Choices) > 0 { + for _, choice := range o.Choices { + var data dto.ImageData + for _, content := range choice.Message.Content { + if content.Image != "" { + if strings.HasPrefix(content.Image, "http") { + var b64Json string + if responseFormat == "b64_json" { + _, b64, err := service.GetImageFromUrl(content.Image) + if err != nil { + logger.LogError(c, "get_image_data_failed: "+err.Error()) + continue + } + b64Json = b64 + } + data.Url = content.Image + data.B64Json = b64Json + } else { + data.B64Json = content.Image + } + } else if content.Text != "" { + data.RevisedPrompt = content.Text + } + } + imageData = append(imageData, data) + } + } + + return imageData +} + +func (o *AliOutput) ResultToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData { + var imageData []dto.ImageData + for _, data := range o.Results { + var b64Json string + if responseFormat == "b64_json" { + _, b64, err := service.GetImageFromUrl(data.Url) + if err != nil { + logger.LogError(c, "get_image_data_failed: "+err.Error()) + continue + } + b64Json = b64 + } else { + b64Json = data.B64Image + } + + imageData = append(imageData, dto.ImageData{ + Url: data.Url, + B64Json: b64Json, + RevisedPrompt: "", + }) + } + return imageData } type AliResponse struct { @@ -92,18 +164,26 @@ type AliResponse struct { } type AliImageRequest struct { - Model string `json:"model"` - Input any `json:"input"` - Parameters any `json:"parameters,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` + Model string `json:"model"` + Input any `json:"input"` + Parameters AliImageParameters `json:"parameters,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` } type AliImageParameters struct { - Size string `json:"size,omitempty"` - N int `json:"n,omitempty"` - Steps string `json:"steps,omitempty"` - Scale string `json:"scale,omitempty"` - Watermark *bool `json:"watermark,omitempty"` + Size string `json:"size,omitempty"` + N int `json:"n,omitempty"` + Steps string `json:"steps,omitempty"` + Scale string `json:"scale,omitempty"` + Watermark *bool `json:"watermark,omitempty"` + PromptExtend *bool `json:"prompt_extend,omitempty"` +} + +func (p *AliImageParameters) PromptExtendValue() bool { + if p != nil && p.PromptExtend != nil { + return *p.PromptExtend + } + return false } type AliImageInput struct { diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go index 0e3fe1ea0..22aacf7d7 100644 --- a/relay/channel/ali/image.go +++ b/relay/channel/ali/image.go @@ -21,17 +21,25 @@ import ( "github.com/gin-gonic/gin" ) -func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) { +func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequest, isSync bool) (*AliImageRequest, error) { var imageRequest AliImageRequest imageRequest.Model = request.Model imageRequest.ResponseFormat = request.ResponseFormat logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra) + logger.LogDebug(context.Background(), "oaiImage2Ali request isSync: "+fmt.Sprintf("%v", isSync)) if request.Extra != nil { if val, ok := request.Extra["parameters"]; ok { err := common.Unmarshal(val, &imageRequest.Parameters) if err != nil { return nil, fmt.Errorf("invalid parameters field: %w", err) } + } else { + // 兼容没有parameters字段的情况,从openai标准字段中提取参数 + imageRequest.Parameters = AliImageParameters{ + Size: strings.Replace(request.Size, "x", "*", -1), + N: int(request.N), + Watermark: request.Watermark, + } } if val, ok := request.Extra["input"]; ok { err := common.Unmarshal(val, &imageRequest.Input) @@ -41,23 +49,44 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) { } } - if imageRequest.Parameters == nil { - imageRequest.Parameters = AliImageParameters{ - Size: strings.Replace(request.Size, "x", "*", -1), - N: int(request.N), - Watermark: request.Watermark, + if strings.Contains(request.Model, "z-image") { + // z-image 开启prompt_extend后,按2倍计费 + if imageRequest.Parameters.PromptExtendValue() { + info.PriceData.AddOtherRatio("prompt_extend", 2) } } - if imageRequest.Input == nil { - imageRequest.Input = AliImageInput{ - Prompt: request.Prompt, + // 检查n参数 + if imageRequest.Parameters.N != 0 { + info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N)) + } + + // 同步图片模型和异步图片模型请求格式不一样 + if isSync { + if imageRequest.Input == nil { + imageRequest.Input = AliImageInput{ + Messages: []AliMessage{ + { + Role: "user", + Content: []AliMediaContent{ + { + Text: request.Prompt, + }, + }, + }, + }, + } + } + } else { + if imageRequest.Input == nil { + imageRequest.Input = AliImageInput{ + Prompt: request.Prompt, + } } } return &imageRequest, nil } - func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) { mf := c.Request.MultipartForm if mf == nil { @@ -199,6 +228,8 @@ func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) ( var taskResponse AliResponse var responseBody []byte + time.Sleep(time.Duration(5) * time.Second) + for { logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds)) step++ @@ -238,32 +269,17 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody [ Created: info.StartTime.Unix(), } - for _, data := range response.Output.Results { - var b64Json string - if responseFormat == "b64_json" { - _, b64, err := service.GetImageFromUrl(data.Url) - if err != nil { - logger.LogError(c, "get_image_data_failed: "+err.Error()) - continue - } - b64Json = b64 - } else { - b64Json = data.B64Image - } - - imageResponse.Data = append(imageResponse.Data, dto.ImageData{ - Url: data.Url, - B64Json: b64Json, - RevisedPrompt: "", - }) + if len(response.Output.Results) > 0 { + imageResponse.Data = response.Output.ResultToOpenAIImageDate(c, responseFormat) + } else if len(response.Output.Choices) > 0 { + imageResponse.Data = response.Output.ChoicesToOpenAIImageDate(c, responseFormat) } - var mapResponse map[string]any - _ = common.Unmarshal(originBody, &mapResponse) - imageResponse.Extra = mapResponse + + imageResponse.Metadata = originBody return &imageResponse } -func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { +func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { responseFormat := c.GetString("response_format") var aliTaskResponse AliResponse @@ -282,66 +298,49 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil } - aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId) - if err != nil { - return types.NewError(err, types.ErrorCodeBadResponse), nil - } + var ( + aliResponse *AliResponse + originRespBody []byte + ) - if aliResponse.Output.TaskStatus != "SUCCEEDED" { - return types.WithOpenAIError(types.OpenAIError{ - Message: aliResponse.Output.Message, - Type: "ali_error", - Param: "", - Code: aliResponse.Output.Code, - }, resp.StatusCode), nil - } - - fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat) - jsonResponse, err := common.Marshal(fullTextResponse) - if err != nil { - return types.NewError(err, types.ErrorCodeBadResponseBody), nil - } - service.IOCopyBytesGracefully(c, resp, jsonResponse) - return nil, &dto.Usage{} -} - -func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { - var aliResponse AliResponse - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil - } - - service.CloseResponseBodyGracefully(resp) - err = common.Unmarshal(responseBody, &aliResponse) - if err != nil { - return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil - } - - if aliResponse.Message != "" { - logger.LogError(c, "ali_task_failed: "+aliResponse.Message) - return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil - } - var fullTextResponse dto.ImageResponse - if len(aliResponse.Output.Choices) > 0 { - fullTextResponse = dto.ImageResponse{ - Created: info.StartTime.Unix(), - Data: []dto.ImageData{ - { - Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string), - B64Json: "", - }, - }, + if a.IsSyncImageModel { + aliResponse = &aliTaskResponse + originRespBody = responseBody + } else { + // 异步图片模型需要轮询任务结果 + aliResponse, originRespBody, err = asyncTaskWait(c, info, aliTaskResponse.Output.TaskId) + if err != nil { + return types.NewError(err, types.ErrorCodeBadResponse), nil + } + if aliResponse.Output.TaskStatus != "SUCCEEDED" { + return types.WithOpenAIError(types.OpenAIError{ + Message: aliResponse.Output.Message, + Type: "ali_error", + Param: "", + Code: aliResponse.Output.Code, + }, resp.StatusCode), nil } } - var mapResponse map[string]any - _ = common.Unmarshal(responseBody, &mapResponse) - fullTextResponse.Extra = mapResponse - jsonResponse, err := common.Marshal(fullTextResponse) + //logger.LogDebug(c, "ali_async_task_result: "+string(originRespBody)) + if a.IsSyncImageModel { + logger.LogDebug(c, "ali_sync_image_result: "+string(originRespBody)) + } else { + logger.LogDebug(c, "ali_async_image_result: "+string(originRespBody)) + } + + imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat) + // 可能生成多张图片,修正计费数量n + if aliResponse.Usage.ImageCount != 0 { + info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount)) + } else if len(imageResponses.Data) != 0 { + info.PriceData.AddOtherRatio("n", float64(len(imageResponses.Data))) + } + jsonResponse, err := common.Marshal(imageResponses) if err != nil { return types.NewError(err, types.ErrorCodeBadResponseBody), nil } service.IOCopyBytesGracefully(c, resp, jsonResponse) + return nil, &dto.Usage{} } diff --git a/relay/channel/ali/image_wan.go b/relay/channel/ali/image_wan.go index 4bd1a2701..90ee48a0b 100644 --- a/relay/channel/ali/image_wan.go +++ b/relay/channel/ali/image_wan.go @@ -26,14 +26,22 @@ func oaiFormEdit2WanxImageEdit(c *gin.Context, info *relaycommon.RelayInfo, requ if wanInput.Images, err = getImageBase64sFromForm(c, "image"); err != nil { return nil, fmt.Errorf("get image base64s from form failed: %w", err) } - wanParams := WanImageParameters{ + //wanParams := WanImageParameters{ + // N: int(request.N), + //} + imageRequest.Input = wanInput + imageRequest.Parameters = AliImageParameters{ N: int(request.N), } - imageRequest.Input = wanInput - imageRequest.Parameters = wanParams + info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N)) + return &imageRequest, nil } +func isOldWanModel(modelName string) bool { + return strings.Contains(modelName, "wan") && !strings.Contains(modelName, "wan2.6") +} + func isWanModel(modelName string) bool { return strings.Contains(modelName, "wan") } diff --git a/relay/compatible_handler.go b/relay/compatible_handler.go index d92c990a7..97649ca96 100644 --- a/relay/compatible_handler.go +++ b/relay/compatible_handler.go @@ -184,19 +184,19 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 { service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") } else { - postConsumeQuota(c, info, usage.(*dto.Usage), "") + postConsumeQuota(c, info, usage.(*dto.Usage)) } return nil } -func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) { +func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent ...string) { if usage == nil { usage = &dto.Usage{ PromptTokens: relayInfo.GetEstimatePromptTokens(), CompletionTokens: 0, TotalTokens: relayInfo.GetEstimatePromptTokens(), } - extraContent += "(可能是请求出错)" + extraContent = append(extraContent, "上游无计费信息") } useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens @@ -246,8 +246,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage dWebSearchQuota = decimal.NewFromFloat(webSearchPrice). Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))). Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) - extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s", - webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()) + extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s", + webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())) } } else if strings.HasSuffix(modelName, "search-preview") { // search-preview 模型不支持 response api @@ -258,8 +258,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize) dWebSearchQuota = decimal.NewFromFloat(webSearchPrice). Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) - extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s", - searchContextSize, dWebSearchQuota.String()) + extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s", + searchContextSize, dWebSearchQuota.String())) } // claude web search tool 计费 var dClaudeWebSearchQuota decimal.Decimal @@ -269,8 +269,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand() dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice). Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount))) - extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s", - claudeWebSearchCallCount, dClaudeWebSearchQuota.String()) + extraContent = append(extraContent, fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s", + claudeWebSearchCallCount, dClaudeWebSearchQuota.String())) } // file search tool 计费 var dFileSearchQuota decimal.Decimal @@ -281,8 +281,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice). Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))). Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) - extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 %s", - fileSearchTool.CallCount, dFileSearchQuota.String()) + extraContent = append(extraContent, fmt.Sprintf("File Search 调用 %d 次,调用花费 %s", + fileSearchTool.CallCount, dFileSearchQuota.String())) } } var dImageGenerationCallQuota decimal.Decimal @@ -290,7 +290,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage if ctx.GetBool("image_generation_call") { imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size")) dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit) - extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String()) + extraContent = append(extraContent, fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())) } var quotaCalculateDecimal decimal.Decimal @@ -331,7 +331,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage // 重新计算 base tokens baseTokens = baseTokens.Sub(dAudioTokens) audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit) - extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()) + extraContent = append(extraContent, fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())) } } promptQuota := baseTokens.Add(cachedTokensWithRatio). @@ -356,17 +356,25 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage // 添加 image generation call 计费 quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota) + if len(relayInfo.PriceData.OtherRatios) > 0 { + for key, otherRatio := range relayInfo.PriceData.OtherRatios { + dOtherRatio := decimal.NewFromFloat(otherRatio) + quotaCalculateDecimal = quotaCalculateDecimal.Mul(dOtherRatio) + extraContent = append(extraContent, fmt.Sprintf("其他倍率 %s: %f", key, otherRatio)) + } + } + quota := int(quotaCalculateDecimal.Round(0).IntPart()) totalTokens := promptTokens + completionTokens - var logContent string + //var logContent string // record all the consume log even if quota is 0 if totalTokens == 0 { // in this case, must be some error happened // we cannot just return, because we may have to return the pre-consumed quota quota = 0 - logContent += fmt.Sprintf("(可能是上游超时)") + extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)") logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) } else { @@ -405,15 +413,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage logModel := modelName if strings.HasPrefix(logModel, "gpt-4-gizmo") { logModel = "gpt-4-gizmo-*" - logContent += fmt.Sprintf(",模型 %s", modelName) + extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName)) } if strings.HasPrefix(logModel, "gpt-4o-gizmo") { logModel = "gpt-4o-gizmo-*" - logContent += fmt.Sprintf(",模型 %s", modelName) - } - if extraContent != "" { - logContent += ", " + extraContent + extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName)) } + logContent := strings.Join(extraContent, ", ") other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) if imageTokens != 0 { other["image"] = true diff --git a/relay/embedding_handler.go b/relay/embedding_handler.go index 740ca400e..2cedf02b5 100644 --- a/relay/embedding_handler.go +++ b/relay/embedding_handler.go @@ -82,6 +82,6 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - postConsumeQuota(c, info, usage.(*dto.Usage), "") + postConsumeQuota(c, info, usage.(*dto.Usage)) return nil } diff --git a/relay/gemini_handler.go b/relay/gemini_handler.go index af13341bf..79ffba515 100644 --- a/relay/gemini_handler.go +++ b/relay/gemini_handler.go @@ -193,7 +193,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ return openaiErr } - postConsumeQuota(c, info, usage.(*dto.Usage), "") + postConsumeQuota(c, info, usage.(*dto.Usage)) return nil } @@ -292,6 +292,6 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI return openaiErr } - postConsumeQuota(c, info, usage.(*dto.Usage), "") + postConsumeQuota(c, info, usage.(*dto.Usage)) return nil } diff --git a/relay/image_handler.go b/relay/image_handler.go index b58968402..f110f4e86 100644 --- a/relay/image_handler.go +++ b/relay/image_handler.go @@ -124,12 +124,18 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type quality = "hd" } - var logContent string + var logContent []string if len(request.Size) > 0 { - logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N) + logContent = append(logContent, fmt.Sprintf("大小 %s", request.Size)) + } + if len(quality) > 0 { + logContent = append(logContent, fmt.Sprintf("品质 %s", quality)) + } + if request.N > 0 { + logContent = append(logContent, fmt.Sprintf("生成数量 %d", request.N)) } - postConsumeQuota(c, info, usage.(*dto.Usage), logContent) + postConsumeQuota(c, info, usage.(*dto.Usage), logContent...) return nil } diff --git a/relay/rerank_handler.go b/relay/rerank_handler.go index 3efc45079..9a50fd271 100644 --- a/relay/rerank_handler.go +++ b/relay/rerank_handler.go @@ -95,6 +95,6 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ service.ResetStatusCode(newAPIError, statusCodeMappingStr) return newAPIError } - postConsumeQuota(c, info, usage.(*dto.Usage), "") + postConsumeQuota(c, info, usage.(*dto.Usage)) return nil } diff --git a/relay/responses_handler.go b/relay/responses_handler.go index 9460356d6..5c3d9a426 100644 --- a/relay/responses_handler.go +++ b/relay/responses_handler.go @@ -107,7 +107,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError * if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") { service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") } else { - postConsumeQuota(c, info, usage.(*dto.Usage), "") + postConsumeQuota(c, info, usage.(*dto.Usage)) } return nil } diff --git a/types/price_data.go b/types/price_data.go index 93044f865..3f7121b8c 100644 --- a/types/price_data.go +++ b/types/price_data.go @@ -26,12 +26,22 @@ type PriceData struct { GroupRatioInfo GroupRatioInfo } +func (p *PriceData) AddOtherRatio(key string, ratio float64) { + if p.OtherRatios == nil { + p.OtherRatios = make(map[string]float64) + } + if ratio <= 0 { + return + } + p.OtherRatios[key] = ratio +} + type PerCallPriceData struct { ModelPrice float64 Quota int GroupRatioInfo GroupRatioInfo } -func (p PriceData) ToSetting() string { +func (p *PriceData) ToSetting() string { return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio) }