mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-29 23:10:35 +00:00
feat(adaptor): 新适配百炼多种图片生成模型
- wan2.6系列生图与编辑,适配多图生成计费 - wan2.5系列生图与编辑 - z-image-turbo生图,适配prompt_extend计费
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -149,6 +150,24 @@ func AddToken(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// 非无限额度时,检查额度值是否超出有效范围
|
||||
if !token.UnlimitedQuota {
|
||||
if token.RemainQuota < 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "额度值不能为负数",
|
||||
})
|
||||
return
|
||||
}
|
||||
maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
|
||||
if token.RemainQuota > maxQuotaValue {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
key, err := common.GenerateKey()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -216,6 +235,23 @@ func UpdateToken(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if !token.UnlimitedQuota {
|
||||
if token.RemainQuota < 0 {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "额度值不能为负数",
|
||||
})
|
||||
return
|
||||
}
|
||||
maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
|
||||
if token.RemainQuota > maxQuotaValue {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
cleanToken, err := model.GetTokenByIds(token.Id, userId)
|
||||
if err != nil {
|
||||
common.ApiError(c, err)
|
||||
@@ -261,7 +297,6 @@ func UpdateToken(c *gin.Context) {
|
||||
"message": "",
|
||||
"data": cleanToken,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type TokenBatch struct {
|
||||
|
||||
@@ -167,9 +167,9 @@ func (i *ImageRequest) SetModelName(modelName string) {
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
Data []ImageData `json:"data"`
|
||||
Created int64 `json:"created"`
|
||||
Extra any `json:"extra,omitempty"`
|
||||
Data []ImageData `json:"data"`
|
||||
Created int64 `json:"created"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
}
|
||||
type ImageData struct {
|
||||
Url string `json:"url"`
|
||||
|
||||
@@ -23,6 +23,8 @@ type FormatJsonSchema struct {
|
||||
Strict json.RawMessage `json:"strict,omitempty"`
|
||||
}
|
||||
|
||||
// GeneralOpenAIRequest represents a general request structure for OpenAI-compatible APIs.
|
||||
// 参数增加规范:无引用的参数必须使用json.RawMessage类型,并添加omitempty标签
|
||||
type GeneralOpenAIRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
@@ -82,8 +84,9 @@ type GeneralOpenAIRequest struct {
|
||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||
// Ali Qwen Params
|
||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||
EnableThinking any `json:"enable_thinking,omitempty"`
|
||||
EnableThinking json.RawMessage `json:"enable_thinking,omitempty"`
|
||||
ChatTemplateKwargs json.RawMessage `json:"chat_template_kwargs,omitempty"`
|
||||
EnableSearch json.RawMessage `json:"enable_search,omitempty"`
|
||||
// ollama Params
|
||||
Think json.RawMessage `json:"think,omitempty"`
|
||||
// baidu v2
|
||||
|
||||
2
main.go
2
main.go
@@ -188,6 +188,7 @@ func InjectUmamiAnalytics() {
|
||||
analyticsInjectBuilder.WriteString(umamiSiteID)
|
||||
analyticsInjectBuilder.WriteString("\"></script>")
|
||||
}
|
||||
analyticsInjectBuilder.WriteString("<!--Umami QuantumNous-->\n")
|
||||
analyticsInject := analyticsInjectBuilder.String()
|
||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--umami-->\n"), []byte(analyticsInject))
|
||||
}
|
||||
@@ -209,6 +210,7 @@ func InjectGoogleAnalytics() {
|
||||
analyticsInjectBuilder.WriteString("');")
|
||||
analyticsInjectBuilder.WriteString("</script>")
|
||||
}
|
||||
analyticsInjectBuilder.WriteString("<!--Google Analytics QuantumNous-->\n")
|
||||
analyticsInject := analyticsInjectBuilder.String()
|
||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--Google Analytics-->\n"), []byte(analyticsInject))
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -19,6 +19,22 @@ import (
|
||||
)
|
||||
|
||||
// Adaptor is the relay adaptor for the Ali channel.
type Adaptor struct {
	// IsSyncImageModel is set by ConvertImageRequest and read by
	// aliImageHandler: when true the upstream response body is converted
	// directly, otherwise the task is polled through asyncTaskWait.
	//
	// NOTE(review): this is per-request state kept on the adaptor; confirm
	// that an Adaptor instance is never shared between concurrent requests.
	IsSyncImageModel bool
}
|
||||
|
||||
// syncModels lists the model-name fragments that identify image models
// treated as synchronous by this adaptor.
var syncModels = []string{
	"z-image",
	"qwen-image",
	"wan2.6",
}

// isSyncImageModel reports whether modelName contains any fragment listed in
// syncModels. Matching is a case-sensitive substring test, so suffixed names
// such as "z-image-turbo" match as well.
func isSyncImageModel(modelName string) bool {
	for _, fragment := range syncModels {
		if strings.Contains(modelName, fragment) {
			return true
		}
	}
	return false
}
|
||||
|
||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||
@@ -45,10 +61,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
case constant.RelayModeRerank:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
||||
if isSyncImageModel(info.OriginModelName) {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
||||
} else {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
||||
}
|
||||
case constant.RelayModeImagesEdits:
|
||||
if isWanModel(info.OriginModelName) {
|
||||
if isOldWanModel(info.OriginModelName) {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image2image/image-synthesis", info.ChannelBaseUrl)
|
||||
} else if isWanModel(info.OriginModelName) {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image-generation/generation", info.ChannelBaseUrl)
|
||||
} else {
|
||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
||||
}
|
||||
@@ -72,7 +94,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
||||
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
req.Set("X-DashScope-Async", "enable")
|
||||
if isSyncImageModel(info.OriginModelName) {
|
||||
|
||||
} else {
|
||||
req.Set("X-DashScope-Async", "enable")
|
||||
}
|
||||
}
|
||||
if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
if isWanModel(info.OriginModelName) {
|
||||
@@ -108,15 +134,25 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||
aliRequest, err := oaiImage2Ali(request)
|
||||
if isSyncImageModel(info.OriginModelName) {
|
||||
a.IsSyncImageModel = true
|
||||
}
|
||||
aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||
return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
} else if info.RelayMode == constant.RelayModeImagesEdits {
|
||||
if isWanModel(info.OriginModelName) {
|
||||
if isOldWanModel(info.OriginModelName) {
|
||||
return oaiFormEdit2WanxImageEdit(c, info, request)
|
||||
}
|
||||
if isSyncImageModel(info.OriginModelName) {
|
||||
if isWanModel(info.OriginModelName) {
|
||||
a.IsSyncImageModel = false
|
||||
} else {
|
||||
a.IsSyncImageModel = true
|
||||
}
|
||||
}
|
||||
// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
|
||||
// 如果用户使用表单,则需要解析表单数据
|
||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||
@@ -126,9 +162,9 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
||||
}
|
||||
return aliRequest, nil
|
||||
} else {
|
||||
aliRequest, err := oaiImage2Ali(request)
|
||||
aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
||||
return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err)
|
||||
}
|
||||
return aliRequest, nil
|
||||
}
|
||||
@@ -169,13 +205,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
||||
default:
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
err, usage = aliImageHandler(a, c, resp, info)
|
||||
case constant.RelayModeImagesEdits:
|
||||
if isWanModel(info.OriginModelName) {
|
||||
err, usage = aliImageHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = aliImageEditHandler(c, resp, info)
|
||||
}
|
||||
err, usage = aliImageHandler(a, c, resp, info)
|
||||
case constant.RelayModeRerank:
|
||||
err, usage = RerankHandler(c, resp, info)
|
||||
default:
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
package ali
|
||||
|
||||
import "github.com/QuantumNous/new-api/dto"
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/logger"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type AliMessage struct {
|
||||
Content any `json:"content"`
|
||||
@@ -65,6 +72,7 @@ type AliUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
ImageCount int `json:"image_count,omitempty"`
|
||||
}
|
||||
|
||||
type TaskResult struct {
|
||||
@@ -75,14 +83,78 @@ type TaskResult struct {
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
TaskId string `json:"task_id,omitempty"`
|
||||
TaskStatus string `json:"task_status,omitempty"`
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Results []TaskResult `json:"results,omitempty"`
|
||||
Choices []map[string]any `json:"choices,omitempty"`
|
||||
TaskId string `json:"task_id,omitempty"`
|
||||
TaskStatus string `json:"task_status,omitempty"`
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Results []TaskResult `json:"results,omitempty"`
|
||||
Choices []struct {
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Message struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content []AliMediaContent `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
} `json:"message,omitempty"`
|
||||
} `json:"choices,omitempty"`
|
||||
}
|
||||
|
||||
func (o *AliOutput) ChoicesToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData {
|
||||
var imageData []dto.ImageData
|
||||
if len(o.Choices) > 0 {
|
||||
for _, choice := range o.Choices {
|
||||
var data dto.ImageData
|
||||
for _, content := range choice.Message.Content {
|
||||
if content.Image != "" {
|
||||
if strings.HasPrefix(content.Image, "http") {
|
||||
var b64Json string
|
||||
if responseFormat == "b64_json" {
|
||||
_, b64, err := service.GetImageFromUrl(content.Image)
|
||||
if err != nil {
|
||||
logger.LogError(c, "get_image_data_failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
b64Json = b64
|
||||
}
|
||||
data.Url = content.Image
|
||||
data.B64Json = b64Json
|
||||
} else {
|
||||
data.B64Json = content.Image
|
||||
}
|
||||
} else if content.Text != "" {
|
||||
data.RevisedPrompt = content.Text
|
||||
}
|
||||
}
|
||||
imageData = append(imageData, data)
|
||||
}
|
||||
}
|
||||
|
||||
return imageData
|
||||
}
|
||||
|
||||
func (o *AliOutput) ResultToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData {
|
||||
var imageData []dto.ImageData
|
||||
for _, data := range o.Results {
|
||||
var b64Json string
|
||||
if responseFormat == "b64_json" {
|
||||
_, b64, err := service.GetImageFromUrl(data.Url)
|
||||
if err != nil {
|
||||
logger.LogError(c, "get_image_data_failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
b64Json = b64
|
||||
} else {
|
||||
b64Json = data.B64Image
|
||||
}
|
||||
|
||||
imageData = append(imageData, dto.ImageData{
|
||||
Url: data.Url,
|
||||
B64Json: b64Json,
|
||||
RevisedPrompt: "",
|
||||
})
|
||||
}
|
||||
return imageData
|
||||
}
|
||||
|
||||
type AliResponse struct {
|
||||
@@ -92,18 +164,26 @@ type AliResponse struct {
|
||||
}
|
||||
|
||||
type AliImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input any `json:"input"`
|
||||
Parameters any `json:"parameters,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Input any `json:"input"`
|
||||
Parameters AliImageParameters `json:"parameters,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
// AliImageParameters is the "parameters" object of an Ali image request.
// All fields are optional and omitted from the JSON when unset.
type AliImageParameters struct {
	// Size is the requested image size. When derived from an OpenAI request
	// the "x" separator is replaced by "*".
	Size string `json:"size,omitempty"`
	// N is the number of images requested; a non-zero value is also used as
	// the "n" billing ratio.
	N     int    `json:"n,omitempty"`
	Steps string `json:"steps,omitempty"`
	Scale string `json:"scale,omitempty"`
	// Watermark is a pointer so that an unset value can be told apart from an
	// explicit false.
	Watermark *bool `json:"watermark,omitempty"`
	// PromptExtend is a pointer for the same reason; read it through
	// PromptExtendValue.
	PromptExtend *bool `json:"prompt_extend,omitempty"`
}

// PromptExtendValue returns the value of PromptExtend, or false when the
// receiver is nil or the field is unset.
func (p *AliImageParameters) PromptExtendValue() bool {
	if p == nil || p.PromptExtend == nil {
		return false
	}
	return *p.PromptExtend
}
|
||||
|
||||
type AliImageInput struct {
|
||||
|
||||
@@ -21,17 +21,25 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequest, isSync bool) (*AliImageRequest, error) {
|
||||
var imageRequest AliImageRequest
|
||||
imageRequest.Model = request.Model
|
||||
imageRequest.ResponseFormat = request.ResponseFormat
|
||||
logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
|
||||
logger.LogDebug(context.Background(), "oaiImage2Ali request isSync: "+fmt.Sprintf("%v", isSync))
|
||||
if request.Extra != nil {
|
||||
if val, ok := request.Extra["parameters"]; ok {
|
||||
err := common.Unmarshal(val, &imageRequest.Parameters)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid parameters field: %w", err)
|
||||
}
|
||||
} else {
|
||||
// 兼容没有parameters字段的情况,从openai标准字段中提取参数
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
Size: strings.Replace(request.Size, "x", "*", -1),
|
||||
N: int(request.N),
|
||||
Watermark: request.Watermark,
|
||||
}
|
||||
}
|
||||
if val, ok := request.Extra["input"]; ok {
|
||||
err := common.Unmarshal(val, &imageRequest.Input)
|
||||
@@ -41,23 +49,44 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if imageRequest.Parameters == nil {
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
Size: strings.Replace(request.Size, "x", "*", -1),
|
||||
N: int(request.N),
|
||||
Watermark: request.Watermark,
|
||||
if strings.Contains(request.Model, "z-image") {
|
||||
// z-image 开启prompt_extend后,按2倍计费
|
||||
if imageRequest.Parameters.PromptExtendValue() {
|
||||
info.PriceData.AddOtherRatio("prompt_extend", 2)
|
||||
}
|
||||
}
|
||||
|
||||
if imageRequest.Input == nil {
|
||||
imageRequest.Input = AliImageInput{
|
||||
Prompt: request.Prompt,
|
||||
// 检查n参数
|
||||
if imageRequest.Parameters.N != 0 {
|
||||
info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
|
||||
}
|
||||
|
||||
// 同步图片模型和异步图片模型请求格式不一样
|
||||
if isSync {
|
||||
if imageRequest.Input == nil {
|
||||
imageRequest.Input = AliImageInput{
|
||||
Messages: []AliMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []AliMediaContent{
|
||||
{
|
||||
Text: request.Prompt,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if imageRequest.Input == nil {
|
||||
imageRequest.Input = AliImageInput{
|
||||
Prompt: request.Prompt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &imageRequest, nil
|
||||
}
|
||||
|
||||
func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
|
||||
mf := c.Request.MultipartForm
|
||||
if mf == nil {
|
||||
@@ -199,6 +228,8 @@ func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (
|
||||
var taskResponse AliResponse
|
||||
var responseBody []byte
|
||||
|
||||
time.Sleep(time.Duration(5) * time.Second)
|
||||
|
||||
for {
|
||||
logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
|
||||
step++
|
||||
@@ -238,32 +269,17 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody [
|
||||
Created: info.StartTime.Unix(),
|
||||
}
|
||||
|
||||
for _, data := range response.Output.Results {
|
||||
var b64Json string
|
||||
if responseFormat == "b64_json" {
|
||||
_, b64, err := service.GetImageFromUrl(data.Url)
|
||||
if err != nil {
|
||||
logger.LogError(c, "get_image_data_failed: "+err.Error())
|
||||
continue
|
||||
}
|
||||
b64Json = b64
|
||||
} else {
|
||||
b64Json = data.B64Image
|
||||
}
|
||||
|
||||
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
|
||||
Url: data.Url,
|
||||
B64Json: b64Json,
|
||||
RevisedPrompt: "",
|
||||
})
|
||||
if len(response.Output.Results) > 0 {
|
||||
imageResponse.Data = response.Output.ResultToOpenAIImageDate(c, responseFormat)
|
||||
} else if len(response.Output.Choices) > 0 {
|
||||
imageResponse.Data = response.Output.ChoicesToOpenAIImageDate(c, responseFormat)
|
||||
}
|
||||
var mapResponse map[string]any
|
||||
_ = common.Unmarshal(originBody, &mapResponse)
|
||||
imageResponse.Extra = mapResponse
|
||||
|
||||
imageResponse.Metadata = originBody
|
||||
return &imageResponse
|
||||
}
|
||||
|
||||
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
responseFormat := c.GetString("response_format")
|
||||
|
||||
var aliTaskResponse AliResponse
|
||||
@@ -282,66 +298,49 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
||||
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
|
||||
aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
var (
|
||||
aliResponse *AliResponse
|
||||
originRespBody []byte
|
||||
)
|
||||
|
||||
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Output.Message,
|
||||
Type: "ali_error",
|
||||
Param: "",
|
||||
Code: aliResponse.Output.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
|
||||
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||
var aliResponse AliResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
service.CloseResponseBodyGracefully(resp)
|
||||
err = common.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Message != "" {
|
||||
logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
|
||||
return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
var fullTextResponse dto.ImageResponse
|
||||
if len(aliResponse.Output.Choices) > 0 {
|
||||
fullTextResponse = dto.ImageResponse{
|
||||
Created: info.StartTime.Unix(),
|
||||
Data: []dto.ImageData{
|
||||
{
|
||||
Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
|
||||
B64Json: "",
|
||||
},
|
||||
},
|
||||
if a.IsSyncImageModel {
|
||||
aliResponse = &aliTaskResponse
|
||||
originRespBody = responseBody
|
||||
} else {
|
||||
// 异步图片模型需要轮询任务结果
|
||||
aliResponse, originRespBody, err = asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
||||
}
|
||||
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
||||
return types.WithOpenAIError(types.OpenAIError{
|
||||
Message: aliResponse.Output.Message,
|
||||
Type: "ali_error",
|
||||
Param: "",
|
||||
Code: aliResponse.Output.Code,
|
||||
}, resp.StatusCode), nil
|
||||
}
|
||||
}
|
||||
|
||||
var mapResponse map[string]any
|
||||
_ = common.Unmarshal(responseBody, &mapResponse)
|
||||
fullTextResponse.Extra = mapResponse
|
||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
||||
//logger.LogDebug(c, "ali_async_task_result: "+string(originRespBody))
|
||||
if a.IsSyncImageModel {
|
||||
logger.LogDebug(c, "ali_sync_image_result: "+string(originRespBody))
|
||||
} else {
|
||||
logger.LogDebug(c, "ali_async_image_result: "+string(originRespBody))
|
||||
}
|
||||
|
||||
imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
|
||||
// 可能生成多张图片,修正计费数量n
|
||||
if aliResponse.Usage.ImageCount != 0 {
|
||||
info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount))
|
||||
} else if len(imageResponses.Data) != 0 {
|
||||
info.PriceData.AddOtherRatio("n", float64(len(imageResponses.Data)))
|
||||
}
|
||||
jsonResponse, err := common.Marshal(imageResponses)
|
||||
if err != nil {
|
||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||
|
||||
return nil, &dto.Usage{}
|
||||
}
|
||||
|
||||
@@ -26,14 +26,22 @@ func oaiFormEdit2WanxImageEdit(c *gin.Context, info *relaycommon.RelayInfo, requ
|
||||
if wanInput.Images, err = getImageBase64sFromForm(c, "image"); err != nil {
|
||||
return nil, fmt.Errorf("get image base64s from form failed: %w", err)
|
||||
}
|
||||
wanParams := WanImageParameters{
|
||||
//wanParams := WanImageParameters{
|
||||
// N: int(request.N),
|
||||
//}
|
||||
imageRequest.Input = wanInput
|
||||
imageRequest.Parameters = AliImageParameters{
|
||||
N: int(request.N),
|
||||
}
|
||||
imageRequest.Input = wanInput
|
||||
imageRequest.Parameters = wanParams
|
||||
info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
|
||||
|
||||
return &imageRequest, nil
|
||||
}
|
||||
|
||||
// isOldWanModel reports whether modelName names a wan-family model other
// than the wan2.6 series, i.e. it contains "wan" but not "wan2.6".
func isOldWanModel(modelName string) bool {
	if !strings.Contains(modelName, "wan") {
		return false
	}
	return !strings.Contains(modelName, "wan2.6")
}
|
||||
|
||||
// isWanModel reports whether modelName contains the substring "wan".
// The match is case-sensitive and covers every wan series, including wan2.6.
func isWanModel(modelName string) bool {
	return strings.Contains(modelName, "wan")
}
|
||||
|
||||
@@ -184,19 +184,19 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
||||
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
|
||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent ...string) {
|
||||
if usage == nil {
|
||||
usage = &dto.Usage{
|
||||
PromptTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: relayInfo.GetEstimatePromptTokens(),
|
||||
}
|
||||
extraContent += "(可能是请求出错)"
|
||||
extraContent = append(extraContent, "上游无计费信息")
|
||||
}
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
@@ -246,8 +246,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
|
||||
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
|
||||
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
|
||||
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()))
|
||||
}
|
||||
} else if strings.HasSuffix(modelName, "search-preview") {
|
||||
// search-preview 模型不支持 response api
|
||||
@@ -258,8 +258,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
|
||||
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
|
||||
searchContextSize, dWebSearchQuota.String())
|
||||
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
|
||||
searchContextSize, dWebSearchQuota.String()))
|
||||
}
|
||||
// claude web search tool 计费
|
||||
var dClaudeWebSearchQuota decimal.Decimal
|
||||
@@ -269,8 +269,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
|
||||
dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
|
||||
extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
|
||||
claudeWebSearchCallCount, dClaudeWebSearchQuota.String())
|
||||
extraContent = append(extraContent, fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
|
||||
claudeWebSearchCallCount, dClaudeWebSearchQuota.String()))
|
||||
}
|
||||
// file search tool 计费
|
||||
var dFileSearchQuota decimal.Decimal
|
||||
@@ -281,8 +281,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
|
||||
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
|
||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
|
||||
fileSearchTool.CallCount, dFileSearchQuota.String())
|
||||
extraContent = append(extraContent, fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
|
||||
fileSearchTool.CallCount, dFileSearchQuota.String()))
|
||||
}
|
||||
}
|
||||
var dImageGenerationCallQuota decimal.Decimal
|
||||
@@ -290,7 +290,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
if ctx.GetBool("image_generation_call") {
|
||||
imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
||||
dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())
|
||||
extraContent = append(extraContent, fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String()))
|
||||
}
|
||||
|
||||
var quotaCalculateDecimal decimal.Decimal
|
||||
@@ -331,7 +331,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
// 重新计算 base tokens
|
||||
baseTokens = baseTokens.Sub(dAudioTokens)
|
||||
audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||
extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
|
||||
extraContent = append(extraContent, fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()))
|
||||
}
|
||||
}
|
||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).
|
||||
@@ -356,17 +356,25 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
// 添加 image generation call 计费
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||
|
||||
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
||||
for key, otherRatio := range relayInfo.PriceData.OtherRatios {
|
||||
dOtherRatio := decimal.NewFromFloat(otherRatio)
|
||||
quotaCalculateDecimal = quotaCalculateDecimal.Mul(dOtherRatio)
|
||||
extraContent = append(extraContent, fmt.Sprintf("其他倍率 %s: %f", key, otherRatio))
|
||||
}
|
||||
}
|
||||
|
||||
quota := int(quotaCalculateDecimal.Round(0).IntPart())
|
||||
totalTokens := promptTokens + completionTokens
|
||||
|
||||
var logContent string
|
||||
//var logContent string
|
||||
|
||||
// record all the consume log even if quota is 0
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游超时)")
|
||||
extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
@@ -405,15 +413,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
||||
logModel := modelName
|
||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||
logModel = "gpt-4-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
|
||||
}
|
||||
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||||
logModel = "gpt-4o-gizmo-*"
|
||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
||||
}
|
||||
if extraContent != "" {
|
||||
logContent += ", " + extraContent
|
||||
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
|
||||
}
|
||||
logContent := strings.Join(extraContent, ", ")
|
||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
if imageTokens != 0 {
|
||||
other["image"] = true
|
||||
|
||||
@@ -82,6 +82,6 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -193,7 +193,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -292,6 +292,6 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
||||
return openaiErr
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -124,12 +124,18 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
||||
quality = "hd"
|
||||
}
|
||||
|
||||
var logContent string
|
||||
var logContent []string
|
||||
|
||||
if len(request.Size) > 0 {
|
||||
logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N)
|
||||
logContent = append(logContent, fmt.Sprintf("大小 %s", request.Size))
|
||||
}
|
||||
if len(quality) > 0 {
|
||||
logContent = append(logContent, fmt.Sprintf("品质 %s", quality))
|
||||
}
|
||||
if request.N > 0 {
|
||||
logContent = append(logContent, fmt.Sprintf("生成数量 %d", request.N))
|
||||
}
|
||||
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent...)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -95,6 +95,6 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||
return newAPIError
|
||||
}
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
||||
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
|
||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
} else {
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -26,12 +26,22 @@ type PriceData struct {
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
// AddOtherRatio stores ratio under key in p.OtherRatios.
// The map is created on first use; this happens before the ratio is validated,
// so OtherRatios is non-nil after any call, even one that stores nothing.
// Ratios that are zero or negative are ignored and leave existing entries untouched.
func (p *PriceData) AddOtherRatio(key string, ratio float64) {
|
||||
// Lazily allocate the map so callers never write to a nil map.
if p.OtherRatios == nil {
|
||||
p.OtherRatios = make(map[string]float64)
|
||||
}
|
||||
// Reject non-positive ratios instead of recording them.
if ratio <= 0 {
|
||||
return
|
||||
}
|
||||
p.OtherRatios[key] = ratio
|
||||
}
|
||||
|
||||
// PerCallPriceData bundles a model price, a quota value and group ratio information.
// NOTE(review): presumably the pricing result used when a request is billed per call
// rather than per token — confirm against callers.
type PerCallPriceData struct {
|
||||
ModelPrice float64
|
||||
Quota int
|
||||
GroupRatioInfo GroupRatioInfo
|
||||
}
|
||||
|
||||
// ToSetting renders the pricing fields of p as a single "Name: value" list
// (ModelPrice, ModelRatio, CompletionRatio, CacheRatio, GroupRatio, UsePrice,
// the cache-creation ratios, QuotaToPreConsume, ImageRatio, AudioRatio and
// AudioCompletionRatio). OtherRatios is not included in the output.
// NOTE(review): the two signature lines below appear to be the before/after
// pair of this diff hunk — the receiver changes from a value to a pointer.
func (p PriceData) ToSetting() string {
|
||||
func (p *PriceData) ToSetting() string {
|
||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user