feat(adaptor): 新适配百炼多种图片生成模型

- wan2.6系列生图与编辑,适配多图生成计费
- wan2.5系列生图与编辑
- z-image-turbo生图,适配prompt_extend计费
This commit is contained in:
CaIon
2025-12-29 22:58:32 +08:00
parent 8063897998
commit 48d358faec
16 changed files with 336 additions and 155 deletions

View File

@@ -1,6 +1,7 @@
package controller package controller
import ( import (
"fmt"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@@ -149,6 +150,24 @@ func AddToken(c *gin.Context) {
}) })
return return
} }
// 非无限额度时,检查额度值是否超出有效范围
if !token.UnlimitedQuota {
if token.RemainQuota < 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "额度值不能为负数",
})
return
}
maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
if token.RemainQuota > maxQuotaValue {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
})
return
}
}
key, err := common.GenerateKey() key, err := common.GenerateKey()
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -216,6 +235,23 @@ func UpdateToken(c *gin.Context) {
}) })
return return
} }
if !token.UnlimitedQuota {
if token.RemainQuota < 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "额度值不能为负数",
})
return
}
maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
if token.RemainQuota > maxQuotaValue {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
})
return
}
}
cleanToken, err := model.GetTokenByIds(token.Id, userId) cleanToken, err := model.GetTokenByIds(token.Id, userId)
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
@@ -261,7 +297,6 @@ func UpdateToken(c *gin.Context) {
"message": "", "message": "",
"data": cleanToken, "data": cleanToken,
}) })
return
} }
type TokenBatch struct { type TokenBatch struct {

View File

@@ -169,7 +169,7 @@ func (i *ImageRequest) SetModelName(modelName string) {
type ImageResponse struct { type ImageResponse struct {
Data []ImageData `json:"data"` Data []ImageData `json:"data"`
Created int64 `json:"created"` Created int64 `json:"created"`
Extra any `json:"extra,omitempty"` Metadata json.RawMessage `json:"metadata,omitempty"`
} }
type ImageData struct { type ImageData struct {
Url string `json:"url"` Url string `json:"url"`

View File

@@ -23,6 +23,8 @@ type FormatJsonSchema struct {
Strict json.RawMessage `json:"strict,omitempty"` Strict json.RawMessage `json:"strict,omitempty"`
} }
// GeneralOpenAIRequest represents a general request structure for OpenAI-compatible APIs.
// 参数增加规范无引用的参数必须使用json.RawMessage类型并添加omitempty标签
type GeneralOpenAIRequest struct { type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"` Messages []Message `json:"messages,omitempty"`
@@ -82,8 +84,9 @@ type GeneralOpenAIRequest struct {
Reasoning json.RawMessage `json:"reasoning,omitempty"` Reasoning json.RawMessage `json:"reasoning,omitempty"`
// Ali Qwen Params // Ali Qwen Params
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"` VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
EnableThinking any `json:"enable_thinking,omitempty"` EnableThinking json.RawMessage `json:"enable_thinking,omitempty"`
ChatTemplateKwargs json.RawMessage `json:"chat_template_kwargs,omitempty"` ChatTemplateKwargs json.RawMessage `json:"chat_template_kwargs,omitempty"`
EnableSearch json.RawMessage `json:"enable_search,omitempty"`
// ollama Params // ollama Params
Think json.RawMessage `json:"think,omitempty"` Think json.RawMessage `json:"think,omitempty"`
// baidu v2 // baidu v2

View File

@@ -188,6 +188,7 @@ func InjectUmamiAnalytics() {
analyticsInjectBuilder.WriteString(umamiSiteID) analyticsInjectBuilder.WriteString(umamiSiteID)
analyticsInjectBuilder.WriteString("\"></script>") analyticsInjectBuilder.WriteString("\"></script>")
} }
analyticsInjectBuilder.WriteString("<!--Umami QuantumNous-->\n")
analyticsInject := analyticsInjectBuilder.String() analyticsInject := analyticsInjectBuilder.String()
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--umami-->\n"), []byte(analyticsInject)) indexPage = bytes.ReplaceAll(indexPage, []byte("<!--umami-->\n"), []byte(analyticsInject))
} }
@@ -209,6 +210,7 @@ func InjectGoogleAnalytics() {
analyticsInjectBuilder.WriteString("');") analyticsInjectBuilder.WriteString("');")
analyticsInjectBuilder.WriteString("</script>") analyticsInjectBuilder.WriteString("</script>")
} }
analyticsInjectBuilder.WriteString("<!--Google Analytics QuantumNous-->\n")
analyticsInject := analyticsInjectBuilder.String() analyticsInject := analyticsInjectBuilder.String()
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--Google Analytics-->\n"), []byte(analyticsInject)) indexPage = bytes.ReplaceAll(indexPage, []byte("<!--Google Analytics-->\n"), []byte(analyticsInject))
} }

View File

@@ -70,7 +70,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 { if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
} else { } else {
postConsumeQuota(c, info, usage.(*dto.Usage), "") postConsumeQuota(c, info, usage.(*dto.Usage))
} }
return nil return nil

View File

@@ -19,6 +19,22 @@ import (
) )
type Adaptor struct { type Adaptor struct {
IsSyncImageModel bool
}
var syncModels = []string{
"z-image",
"qwen-image",
"wan2.6",
}
func isSyncImageModel(modelName string) bool {
for _, m := range syncModels {
if strings.Contains(modelName, m) {
return true
}
}
return false
} }
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
@@ -45,10 +61,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
case constant.RelayModeRerank: case constant.RelayModeRerank:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl) fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
case constant.RelayModeImagesGenerations: case constant.RelayModeImagesGenerations:
if isSyncImageModel(info.OriginModelName) {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
} else {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl) fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
}
case constant.RelayModeImagesEdits: case constant.RelayModeImagesEdits:
if isWanModel(info.OriginModelName) { if isOldWanModel(info.OriginModelName) {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image2image/image-synthesis", info.ChannelBaseUrl) fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image2image/image-synthesis", info.ChannelBaseUrl)
} else if isWanModel(info.OriginModelName) {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image-generation/generation", info.ChannelBaseUrl)
} else { } else {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl) fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
} }
@@ -72,8 +94,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
req.Set("X-DashScope-Plugin", c.GetString("plugin")) req.Set("X-DashScope-Plugin", c.GetString("plugin"))
} }
if info.RelayMode == constant.RelayModeImagesGenerations { if info.RelayMode == constant.RelayModeImagesGenerations {
if isSyncImageModel(info.OriginModelName) {
} else {
req.Set("X-DashScope-Async", "enable") req.Set("X-DashScope-Async", "enable")
} }
}
if info.RelayMode == constant.RelayModeImagesEdits { if info.RelayMode == constant.RelayModeImagesEdits {
if isWanModel(info.OriginModelName) { if isWanModel(info.OriginModelName) {
req.Set("X-DashScope-Async", "enable") req.Set("X-DashScope-Async", "enable")
@@ -108,15 +134,25 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
if info.RelayMode == constant.RelayModeImagesGenerations { if info.RelayMode == constant.RelayModeImagesGenerations {
aliRequest, err := oaiImage2Ali(request) if isSyncImageModel(info.OriginModelName) {
a.IsSyncImageModel = true
}
aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel)
if err != nil { if err != nil {
return nil, fmt.Errorf("convert image request failed: %w", err) return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err)
} }
return aliRequest, nil return aliRequest, nil
} else if info.RelayMode == constant.RelayModeImagesEdits { } else if info.RelayMode == constant.RelayModeImagesEdits {
if isWanModel(info.OriginModelName) { if isOldWanModel(info.OriginModelName) {
return oaiFormEdit2WanxImageEdit(c, info, request) return oaiFormEdit2WanxImageEdit(c, info, request)
} }
if isSyncImageModel(info.OriginModelName) {
if isWanModel(info.OriginModelName) {
a.IsSyncImageModel = false
} else {
a.IsSyncImageModel = true
}
}
// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416 // ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
// 如果用户使用表单,则需要解析表单数据 // 如果用户使用表单,则需要解析表单数据
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") { if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
@@ -126,9 +162,9 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
} }
return aliRequest, nil return aliRequest, nil
} else { } else {
aliRequest, err := oaiImage2Ali(request) aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel)
if err != nil { if err != nil {
return nil, fmt.Errorf("convert image request failed: %w", err) return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err)
} }
return aliRequest, nil return aliRequest, nil
} }
@@ -169,13 +205,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
default: default:
switch info.RelayMode { switch info.RelayMode {
case constant.RelayModeImagesGenerations: case constant.RelayModeImagesGenerations:
err, usage = aliImageHandler(c, resp, info) err, usage = aliImageHandler(a, c, resp, info)
case constant.RelayModeImagesEdits: case constant.RelayModeImagesEdits:
if isWanModel(info.OriginModelName) { err, usage = aliImageHandler(a, c, resp, info)
err, usage = aliImageHandler(c, resp, info)
} else {
err, usage = aliImageEditHandler(c, resp, info)
}
case constant.RelayModeRerank: case constant.RelayModeRerank:
err, usage = RerankHandler(c, resp, info) err, usage = RerankHandler(c, resp, info)
default: default:

View File

@@ -1,6 +1,13 @@
package ali package ali
import "github.com/QuantumNous/new-api/dto" import (
"strings"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/service"
"github.com/gin-gonic/gin"
)
type AliMessage struct { type AliMessage struct {
Content any `json:"content"` Content any `json:"content"`
@@ -65,6 +72,7 @@ type AliUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"` TotalTokens int `json:"total_tokens"`
ImageCount int `json:"image_count,omitempty"`
} }
type TaskResult struct { type TaskResult struct {
@@ -82,7 +90,71 @@ type AliOutput struct {
Message string `json:"message,omitempty"` Message string `json:"message,omitempty"`
Code string `json:"code,omitempty"` Code string `json:"code,omitempty"`
Results []TaskResult `json:"results,omitempty"` Results []TaskResult `json:"results,omitempty"`
Choices []map[string]any `json:"choices,omitempty"` Choices []struct {
FinishReason string `json:"finish_reason,omitempty"`
Message struct {
Role string `json:"role,omitempty"`
Content []AliMediaContent `json:"content,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
} `json:"message,omitempty"`
} `json:"choices,omitempty"`
}
func (o *AliOutput) ChoicesToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData {
var imageData []dto.ImageData
if len(o.Choices) > 0 {
for _, choice := range o.Choices {
var data dto.ImageData
for _, content := range choice.Message.Content {
if content.Image != "" {
if strings.HasPrefix(content.Image, "http") {
var b64Json string
if responseFormat == "b64_json" {
_, b64, err := service.GetImageFromUrl(content.Image)
if err != nil {
logger.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
}
data.Url = content.Image
data.B64Json = b64Json
} else {
data.B64Json = content.Image
}
} else if content.Text != "" {
data.RevisedPrompt = content.Text
}
}
imageData = append(imageData, data)
}
}
return imageData
}
func (o *AliOutput) ResultToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData {
var imageData []dto.ImageData
for _, data := range o.Results {
var b64Json string
if responseFormat == "b64_json" {
_, b64, err := service.GetImageFromUrl(data.Url)
if err != nil {
logger.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
} else {
b64Json = data.B64Image
}
imageData = append(imageData, dto.ImageData{
Url: data.Url,
B64Json: b64Json,
RevisedPrompt: "",
})
}
return imageData
} }
type AliResponse struct { type AliResponse struct {
@@ -94,7 +166,7 @@ type AliResponse struct {
type AliImageRequest struct { type AliImageRequest struct {
Model string `json:"model"` Model string `json:"model"`
Input any `json:"input"` Input any `json:"input"`
Parameters any `json:"parameters,omitempty"` Parameters AliImageParameters `json:"parameters,omitempty"`
ResponseFormat string `json:"response_format,omitempty"` ResponseFormat string `json:"response_format,omitempty"`
} }
@@ -104,6 +176,14 @@ type AliImageParameters struct {
Steps string `json:"steps,omitempty"` Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"` Scale string `json:"scale,omitempty"`
Watermark *bool `json:"watermark,omitempty"` Watermark *bool `json:"watermark,omitempty"`
PromptExtend *bool `json:"prompt_extend,omitempty"`
}
func (p *AliImageParameters) PromptExtendValue() bool {
if p != nil && p.PromptExtend != nil {
return *p.PromptExtend
}
return false
} }
type AliImageInput struct { type AliImageInput struct {

View File

@@ -21,17 +21,25 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) { func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequest, isSync bool) (*AliImageRequest, error) {
var imageRequest AliImageRequest var imageRequest AliImageRequest
imageRequest.Model = request.Model imageRequest.Model = request.Model
imageRequest.ResponseFormat = request.ResponseFormat imageRequest.ResponseFormat = request.ResponseFormat
logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra) logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
logger.LogDebug(context.Background(), "oaiImage2Ali request isSync: "+fmt.Sprintf("%v", isSync))
if request.Extra != nil { if request.Extra != nil {
if val, ok := request.Extra["parameters"]; ok { if val, ok := request.Extra["parameters"]; ok {
err := common.Unmarshal(val, &imageRequest.Parameters) err := common.Unmarshal(val, &imageRequest.Parameters)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid parameters field: %w", err) return nil, fmt.Errorf("invalid parameters field: %w", err)
} }
} else {
// 兼容没有parameters字段的情况从openai标准字段中提取参数
imageRequest.Parameters = AliImageParameters{
Size: strings.Replace(request.Size, "x", "*", -1),
N: int(request.N),
Watermark: request.Watermark,
}
} }
if val, ok := request.Extra["input"]; ok { if val, ok := request.Extra["input"]; ok {
err := common.Unmarshal(val, &imageRequest.Input) err := common.Unmarshal(val, &imageRequest.Input)
@@ -41,23 +49,44 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
} }
} }
if imageRequest.Parameters == nil { if strings.Contains(request.Model, "z-image") {
imageRequest.Parameters = AliImageParameters{ // z-image 开启prompt_extend后按2倍计费
Size: strings.Replace(request.Size, "x", "*", -1), if imageRequest.Parameters.PromptExtendValue() {
N: int(request.N), info.PriceData.AddOtherRatio("prompt_extend", 2)
Watermark: request.Watermark,
} }
} }
// 检查n参数
if imageRequest.Parameters.N != 0 {
info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
}
// 同步图片模型和异步图片模型请求格式不一样
if isSync {
if imageRequest.Input == nil {
imageRequest.Input = AliImageInput{
Messages: []AliMessage{
{
Role: "user",
Content: []AliMediaContent{
{
Text: request.Prompt,
},
},
},
},
}
}
} else {
if imageRequest.Input == nil { if imageRequest.Input == nil {
imageRequest.Input = AliImageInput{ imageRequest.Input = AliImageInput{
Prompt: request.Prompt, Prompt: request.Prompt,
} }
} }
}
return &imageRequest, nil return &imageRequest, nil
} }
func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) { func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
mf := c.Request.MultipartForm mf := c.Request.MultipartForm
if mf == nil { if mf == nil {
@@ -199,6 +228,8 @@ func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (
var taskResponse AliResponse var taskResponse AliResponse
var responseBody []byte var responseBody []byte
time.Sleep(time.Duration(5) * time.Second)
for { for {
logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds)) logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
step++ step++
@@ -238,32 +269,17 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody [
Created: info.StartTime.Unix(), Created: info.StartTime.Unix(),
} }
for _, data := range response.Output.Results { if len(response.Output.Results) > 0 {
var b64Json string imageResponse.Data = response.Output.ResultToOpenAIImageDate(c, responseFormat)
if responseFormat == "b64_json" { } else if len(response.Output.Choices) > 0 {
_, b64, err := service.GetImageFromUrl(data.Url) imageResponse.Data = response.Output.ChoicesToOpenAIImageDate(c, responseFormat)
if err != nil {
logger.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
} else {
b64Json = data.B64Image
} }
imageResponse.Data = append(imageResponse.Data, dto.ImageData{ imageResponse.Metadata = originBody
Url: data.Url,
B64Json: b64Json,
RevisedPrompt: "",
})
}
var mapResponse map[string]any
_ = common.Unmarshal(originBody, &mapResponse)
imageResponse.Extra = mapResponse
return &imageResponse return &imageResponse
} }
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
responseFormat := c.GetString("response_format") responseFormat := c.GetString("response_format")
var aliTaskResponse AliResponse var aliTaskResponse AliResponse
@@ -282,11 +298,20 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
} }
aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId) var (
aliResponse *AliResponse
originRespBody []byte
)
if a.IsSyncImageModel {
aliResponse = &aliTaskResponse
originRespBody = responseBody
} else {
// 异步图片模型需要轮询任务结果
aliResponse, originRespBody, err = asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponse), nil return types.NewError(err, types.ErrorCodeBadResponse), nil
} }
if aliResponse.Output.TaskStatus != "SUCCEEDED" { if aliResponse.Output.TaskStatus != "SUCCEEDED" {
return types.WithOpenAIError(types.OpenAIError{ return types.WithOpenAIError(types.OpenAIError{
Message: aliResponse.Output.Message, Message: aliResponse.Output.Message,
@@ -295,53 +320,27 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
Code: aliResponse.Output.Code, Code: aliResponse.Output.Code,
}, resp.StatusCode), nil }, resp.StatusCode), nil
} }
}
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat) //logger.LogDebug(c, "ali_async_task_result: "+string(originRespBody))
jsonResponse, err := common.Marshal(fullTextResponse) if a.IsSyncImageModel {
if err != nil { logger.LogDebug(c, "ali_sync_image_result: "+string(originRespBody))
return types.NewError(err, types.ErrorCodeBadResponseBody), nil } else {
} logger.LogDebug(c, "ali_async_image_result: "+string(originRespBody))
service.IOCopyBytesGracefully(c, resp, jsonResponse) }
return nil, &dto.Usage{}
} imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
// 可能生成多张图片修正计费数量n
func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { if aliResponse.Usage.ImageCount != 0 {
var aliResponse AliResponse info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount))
responseBody, err := io.ReadAll(resp.Body) } else if len(imageResponses.Data) != 0 {
if err != nil { info.PriceData.AddOtherRatio("n", float64(len(imageResponses.Data)))
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil }
} jsonResponse, err := common.Marshal(imageResponses)
service.CloseResponseBodyGracefully(resp)
err = common.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliResponse.Message != "" {
logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
}
var fullTextResponse dto.ImageResponse
if len(aliResponse.Output.Choices) > 0 {
fullTextResponse = dto.ImageResponse{
Created: info.StartTime.Unix(),
Data: []dto.ImageData{
{
Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
B64Json: "",
},
},
}
}
var mapResponse map[string]any
_ = common.Unmarshal(responseBody, &mapResponse)
fullTextResponse.Extra = mapResponse
jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return types.NewError(err, types.ErrorCodeBadResponseBody), nil
} }
service.IOCopyBytesGracefully(c, resp, jsonResponse) service.IOCopyBytesGracefully(c, resp, jsonResponse)
return nil, &dto.Usage{} return nil, &dto.Usage{}
} }

View File

@@ -26,14 +26,22 @@ func oaiFormEdit2WanxImageEdit(c *gin.Context, info *relaycommon.RelayInfo, requ
if wanInput.Images, err = getImageBase64sFromForm(c, "image"); err != nil { if wanInput.Images, err = getImageBase64sFromForm(c, "image"); err != nil {
return nil, fmt.Errorf("get image base64s from form failed: %w", err) return nil, fmt.Errorf("get image base64s from form failed: %w", err)
} }
wanParams := WanImageParameters{ //wanParams := WanImageParameters{
// N: int(request.N),
//}
imageRequest.Input = wanInput
imageRequest.Parameters = AliImageParameters{
N: int(request.N), N: int(request.N),
} }
imageRequest.Input = wanInput info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
imageRequest.Parameters = wanParams
return &imageRequest, nil return &imageRequest, nil
} }
func isOldWanModel(modelName string) bool {
return strings.Contains(modelName, "wan") && !strings.Contains(modelName, "wan2.6")
}
func isWanModel(modelName string) bool { func isWanModel(modelName string) bool {
return strings.Contains(modelName, "wan") return strings.Contains(modelName, "wan")
} }

View File

@@ -184,19 +184,19 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 { if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
} else { } else {
postConsumeQuota(c, info, usage.(*dto.Usage), "") postConsumeQuota(c, info, usage.(*dto.Usage))
} }
return nil return nil
} }
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) { func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent ...string) {
if usage == nil { if usage == nil {
usage = &dto.Usage{ usage = &dto.Usage{
PromptTokens: relayInfo.GetEstimatePromptTokens(), PromptTokens: relayInfo.GetEstimatePromptTokens(),
CompletionTokens: 0, CompletionTokens: 0,
TotalTokens: relayInfo.GetEstimatePromptTokens(), TotalTokens: relayInfo.GetEstimatePromptTokens(),
} }
extraContent += "(可能是请求出错)" extraContent = append(extraContent, "上游无计费信息")
} }
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens promptTokens := usage.PromptTokens
@@ -246,8 +246,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice). dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))). Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s调用花费 %s", extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s调用花费 %s",
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()) webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()))
} }
} else if strings.HasSuffix(modelName, "search-preview") { } else if strings.HasSuffix(modelName, "search-preview") {
// search-preview 模型不支持 response api // search-preview 模型不支持 response api
@@ -258,8 +258,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize) webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice). dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s调用花费 %s", extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s调用花费 %s",
searchContextSize, dWebSearchQuota.String()) searchContextSize, dWebSearchQuota.String()))
} }
// claude web search tool 计费 // claude web search tool 计费
var dClaudeWebSearchQuota decimal.Decimal var dClaudeWebSearchQuota decimal.Decimal
@@ -269,8 +269,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand() claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice). dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount))) Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s", extraContent = append(extraContent, fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
claudeWebSearchCallCount, dClaudeWebSearchQuota.String()) claudeWebSearchCallCount, dClaudeWebSearchQuota.String()))
} }
// file search tool 计费 // file search tool 计费
var dFileSearchQuota decimal.Decimal var dFileSearchQuota decimal.Decimal
@@ -281,8 +281,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice). dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))). Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit) Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 %s", extraContent = append(extraContent, fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
fileSearchTool.CallCount, dFileSearchQuota.String()) fileSearchTool.CallCount, dFileSearchQuota.String()))
} }
} }
var dImageGenerationCallQuota decimal.Decimal var dImageGenerationCallQuota decimal.Decimal
@@ -290,7 +290,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
if ctx.GetBool("image_generation_call") { if ctx.GetBool("image_generation_call") {
imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size")) imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit) dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String()) extraContent = append(extraContent, fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String()))
} }
var quotaCalculateDecimal decimal.Decimal var quotaCalculateDecimal decimal.Decimal
@@ -331,7 +331,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
// 重新计算 base tokens // 重新计算 base tokens
baseTokens = baseTokens.Sub(dAudioTokens) baseTokens = baseTokens.Sub(dAudioTokens)
audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit) audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()) extraContent = append(extraContent, fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()))
} }
} }
promptQuota := baseTokens.Add(cachedTokensWithRatio). promptQuota := baseTokens.Add(cachedTokensWithRatio).
@@ -356,17 +356,25 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
// 添加 image generation call 计费 // 添加 image generation call 计费
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota) quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
if len(relayInfo.PriceData.OtherRatios) > 0 {
for key, otherRatio := range relayInfo.PriceData.OtherRatios {
dOtherRatio := decimal.NewFromFloat(otherRatio)
quotaCalculateDecimal = quotaCalculateDecimal.Mul(dOtherRatio)
extraContent = append(extraContent, fmt.Sprintf("其他倍率 %s: %f", key, otherRatio))
}
}
quota := int(quotaCalculateDecimal.Round(0).IntPart()) quota := int(quotaCalculateDecimal.Round(0).IntPart())
totalTokens := promptTokens + completionTokens totalTokens := promptTokens + completionTokens
var logContent string //var logContent string
// record all the consume log even if quota is 0 // record all the consume log even if quota is 0
if totalTokens == 0 { if totalTokens == 0 {
// in this case, must be some error happened // in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota // we cannot just return, because we may have to return the pre-consumed quota
quota = 0 quota = 0
logContent += fmt.Sprintf("(可能是上游超时)") extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota)) "tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
} else { } else {
@@ -405,15 +413,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
logModel := modelName logModel := modelName
if strings.HasPrefix(logModel, "gpt-4-gizmo") { if strings.HasPrefix(logModel, "gpt-4-gizmo") {
logModel = "gpt-4-gizmo-*" logModel = "gpt-4-gizmo-*"
logContent += fmt.Sprintf("模型 %s", modelName) extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
} }
if strings.HasPrefix(logModel, "gpt-4o-gizmo") { if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
logModel = "gpt-4o-gizmo-*" logModel = "gpt-4o-gizmo-*"
logContent += fmt.Sprintf("模型 %s", modelName) extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
}
if extraContent != "" {
logContent += ", " + extraContent
} }
logContent := strings.Join(extraContent, ", ")
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio) other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
if imageTokens != 0 { if imageTokens != 0 {
other["image"] = true other["image"] = true

View File

@@ -82,6 +82,6 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
service.ResetStatusCode(newAPIError, statusCodeMappingStr) service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError return newAPIError
} }
postConsumeQuota(c, info, usage.(*dto.Usage), "") postConsumeQuota(c, info, usage.(*dto.Usage))
return nil return nil
} }

View File

@@ -193,7 +193,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
return openaiErr return openaiErr
} }
postConsumeQuota(c, info, usage.(*dto.Usage), "") postConsumeQuota(c, info, usage.(*dto.Usage))
return nil return nil
} }
@@ -292,6 +292,6 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
return openaiErr return openaiErr
} }
postConsumeQuota(c, info, usage.(*dto.Usage), "") postConsumeQuota(c, info, usage.(*dto.Usage))
return nil return nil
} }

View File

@@ -124,12 +124,18 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
quality = "hd" quality = "hd"
} }
var logContent string var logContent []string
if len(request.Size) > 0 { if len(request.Size) > 0 {
logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N) logContent = append(logContent, fmt.Sprintf("大小 %s", request.Size))
}
if len(quality) > 0 {
logContent = append(logContent, fmt.Sprintf("品质 %s", quality))
}
if request.N > 0 {
logContent = append(logContent, fmt.Sprintf("生成数量 %d", request.N))
} }
postConsumeQuota(c, info, usage.(*dto.Usage), logContent) postConsumeQuota(c, info, usage.(*dto.Usage), logContent...)
return nil return nil
} }

View File

@@ -95,6 +95,6 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
service.ResetStatusCode(newAPIError, statusCodeMappingStr) service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError return newAPIError
} }
postConsumeQuota(c, info, usage.(*dto.Usage), "") postConsumeQuota(c, info, usage.(*dto.Usage))
return nil return nil
} }

View File

@@ -107,7 +107,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") { if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
} else { } else {
postConsumeQuota(c, info, usage.(*dto.Usage), "") postConsumeQuota(c, info, usage.(*dto.Usage))
} }
return nil return nil
} }

View File

@@ -26,12 +26,22 @@ type PriceData struct {
GroupRatioInfo GroupRatioInfo GroupRatioInfo GroupRatioInfo
} }
func (p *PriceData) AddOtherRatio(key string, ratio float64) {
if p.OtherRatios == nil {
p.OtherRatios = make(map[string]float64)
}
if ratio <= 0 {
return
}
p.OtherRatios[key] = ratio
}
type PerCallPriceData struct { type PerCallPriceData struct {
ModelPrice float64 ModelPrice float64
Quota int Quota int
GroupRatioInfo GroupRatioInfo GroupRatioInfo GroupRatioInfo
} }
func (p PriceData) ToSetting() string { func (p *PriceData) ToSetting() string {
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio) return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
} }