mirror of
https://github.com/QuantumNous/new-api.git
synced 2026-03-30 02:05:21 +00:00
feat(adaptor): 新适配百炼多种图片生成模型
- wan2.6系列生图与编辑,适配多图生成计费 - wan2.5系列生图与编辑 - z-image-turbo生图,适配prompt_extend计费
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -149,6 +150,24 @@ func AddToken(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 非无限额度时,检查额度值是否超出有效范围
|
||||||
|
if !token.UnlimitedQuota {
|
||||||
|
if token.RemainQuota < 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "额度值不能为负数",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
|
||||||
|
if token.RemainQuota > maxQuotaValue {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
key, err := common.GenerateKey()
|
key, err := common.GenerateKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
@@ -216,6 +235,23 @@ func UpdateToken(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if !token.UnlimitedQuota {
|
||||||
|
if token.RemainQuota < 0 {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": "额度值不能为负数",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
maxQuotaValue := int((1000000000 * common.QuotaPerUnit))
|
||||||
|
if token.RemainQuota > maxQuotaValue {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": fmt.Sprintf("额度值超出有效范围,最大值为 %d", maxQuotaValue),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
cleanToken, err := model.GetTokenByIds(token.Id, userId)
|
cleanToken, err := model.GetTokenByIds(token.Id, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.ApiError(c, err)
|
common.ApiError(c, err)
|
||||||
@@ -261,7 +297,6 @@ func UpdateToken(c *gin.Context) {
|
|||||||
"message": "",
|
"message": "",
|
||||||
"data": cleanToken,
|
"data": cleanToken,
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenBatch struct {
|
type TokenBatch struct {
|
||||||
|
|||||||
@@ -169,7 +169,7 @@ func (i *ImageRequest) SetModelName(modelName string) {
|
|||||||
type ImageResponse struct {
|
type ImageResponse struct {
|
||||||
Data []ImageData `json:"data"`
|
Data []ImageData `json:"data"`
|
||||||
Created int64 `json:"created"`
|
Created int64 `json:"created"`
|
||||||
Extra any `json:"extra,omitempty"`
|
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||||
}
|
}
|
||||||
type ImageData struct {
|
type ImageData struct {
|
||||||
Url string `json:"url"`
|
Url string `json:"url"`
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ type FormatJsonSchema struct {
|
|||||||
Strict json.RawMessage `json:"strict,omitempty"`
|
Strict json.RawMessage `json:"strict,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GeneralOpenAIRequest represents a general request structure for OpenAI-compatible APIs.
|
||||||
|
// 参数增加规范:无引用的参数必须使用json.RawMessage类型,并添加omitempty标签
|
||||||
type GeneralOpenAIRequest struct {
|
type GeneralOpenAIRequest struct {
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
Messages []Message `json:"messages,omitempty"`
|
Messages []Message `json:"messages,omitempty"`
|
||||||
@@ -82,8 +84,9 @@ type GeneralOpenAIRequest struct {
|
|||||||
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
Reasoning json.RawMessage `json:"reasoning,omitempty"`
|
||||||
// Ali Qwen Params
|
// Ali Qwen Params
|
||||||
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
VlHighResolutionImages json.RawMessage `json:"vl_high_resolution_images,omitempty"`
|
||||||
EnableThinking any `json:"enable_thinking,omitempty"`
|
EnableThinking json.RawMessage `json:"enable_thinking,omitempty"`
|
||||||
ChatTemplateKwargs json.RawMessage `json:"chat_template_kwargs,omitempty"`
|
ChatTemplateKwargs json.RawMessage `json:"chat_template_kwargs,omitempty"`
|
||||||
|
EnableSearch json.RawMessage `json:"enable_search,omitempty"`
|
||||||
// ollama Params
|
// ollama Params
|
||||||
Think json.RawMessage `json:"think,omitempty"`
|
Think json.RawMessage `json:"think,omitempty"`
|
||||||
// baidu v2
|
// baidu v2
|
||||||
|
|||||||
2
main.go
2
main.go
@@ -188,6 +188,7 @@ func InjectUmamiAnalytics() {
|
|||||||
analyticsInjectBuilder.WriteString(umamiSiteID)
|
analyticsInjectBuilder.WriteString(umamiSiteID)
|
||||||
analyticsInjectBuilder.WriteString("\"></script>")
|
analyticsInjectBuilder.WriteString("\"></script>")
|
||||||
}
|
}
|
||||||
|
analyticsInjectBuilder.WriteString("<!--Umami QuantumNous-->\n")
|
||||||
analyticsInject := analyticsInjectBuilder.String()
|
analyticsInject := analyticsInjectBuilder.String()
|
||||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--umami-->\n"), []byte(analyticsInject))
|
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--umami-->\n"), []byte(analyticsInject))
|
||||||
}
|
}
|
||||||
@@ -209,6 +210,7 @@ func InjectGoogleAnalytics() {
|
|||||||
analyticsInjectBuilder.WriteString("');")
|
analyticsInjectBuilder.WriteString("');")
|
||||||
analyticsInjectBuilder.WriteString("</script>")
|
analyticsInjectBuilder.WriteString("</script>")
|
||||||
}
|
}
|
||||||
|
analyticsInjectBuilder.WriteString("<!--Google Analytics QuantumNous-->\n")
|
||||||
analyticsInject := analyticsInjectBuilder.String()
|
analyticsInject := analyticsInjectBuilder.String()
|
||||||
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--Google Analytics-->\n"), []byte(analyticsInject))
|
indexPage = bytes.ReplaceAll(indexPage, []byte("<!--Google Analytics-->\n"), []byte(analyticsInject))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
|||||||
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
|
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
|
||||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||||
} else {
|
} else {
|
||||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -19,6 +19,22 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
|
IsSyncImageModel bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var syncModels = []string{
|
||||||
|
"z-image",
|
||||||
|
"qwen-image",
|
||||||
|
"wan2.6",
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSyncImageModel(modelName string) bool {
|
||||||
|
for _, m := range syncModels {
|
||||||
|
if strings.Contains(modelName, m) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
|
||||||
@@ -45,10 +61,16 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
|
||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
|
if isSyncImageModel(info.OriginModelName) {
|
||||||
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
||||||
|
} else {
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
|
||||||
|
}
|
||||||
case constant.RelayModeImagesEdits:
|
case constant.RelayModeImagesEdits:
|
||||||
if isWanModel(info.OriginModelName) {
|
if isOldWanModel(info.OriginModelName) {
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image2image/image-synthesis", info.ChannelBaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image2image/image-synthesis", info.ChannelBaseUrl)
|
||||||
|
} else if isWanModel(info.OriginModelName) {
|
||||||
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/image-generation/generation", info.ChannelBaseUrl)
|
||||||
} else {
|
} else {
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/multimodal-generation/generation", info.ChannelBaseUrl)
|
||||||
}
|
}
|
||||||
@@ -72,8 +94,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
|
|||||||
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||||
}
|
}
|
||||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||||
|
if isSyncImageModel(info.OriginModelName) {
|
||||||
|
|
||||||
|
} else {
|
||||||
req.Set("X-DashScope-Async", "enable")
|
req.Set("X-DashScope-Async", "enable")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if info.RelayMode == constant.RelayModeImagesEdits {
|
if info.RelayMode == constant.RelayModeImagesEdits {
|
||||||
if isWanModel(info.OriginModelName) {
|
if isWanModel(info.OriginModelName) {
|
||||||
req.Set("X-DashScope-Async", "enable")
|
req.Set("X-DashScope-Async", "enable")
|
||||||
@@ -108,15 +134,25 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||||
if info.RelayMode == constant.RelayModeImagesGenerations {
|
if info.RelayMode == constant.RelayModeImagesGenerations {
|
||||||
aliRequest, err := oaiImage2Ali(request)
|
if isSyncImageModel(info.OriginModelName) {
|
||||||
|
a.IsSyncImageModel = true
|
||||||
|
}
|
||||||
|
aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err)
|
||||||
}
|
}
|
||||||
return aliRequest, nil
|
return aliRequest, nil
|
||||||
} else if info.RelayMode == constant.RelayModeImagesEdits {
|
} else if info.RelayMode == constant.RelayModeImagesEdits {
|
||||||
if isWanModel(info.OriginModelName) {
|
if isOldWanModel(info.OriginModelName) {
|
||||||
return oaiFormEdit2WanxImageEdit(c, info, request)
|
return oaiFormEdit2WanxImageEdit(c, info, request)
|
||||||
}
|
}
|
||||||
|
if isSyncImageModel(info.OriginModelName) {
|
||||||
|
if isWanModel(info.OriginModelName) {
|
||||||
|
a.IsSyncImageModel = false
|
||||||
|
} else {
|
||||||
|
a.IsSyncImageModel = true
|
||||||
|
}
|
||||||
|
}
|
||||||
// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
|
// ali image edit https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2976416
|
||||||
// 如果用户使用表单,则需要解析表单数据
|
// 如果用户使用表单,则需要解析表单数据
|
||||||
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
if strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
|
||||||
@@ -126,9 +162,9 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
}
|
}
|
||||||
return aliRequest, nil
|
return aliRequest, nil
|
||||||
} else {
|
} else {
|
||||||
aliRequest, err := oaiImage2Ali(request)
|
aliRequest, err := oaiImage2AliImageRequest(info, request, a.IsSyncImageModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("convert image request failed: %w", err)
|
return nil, fmt.Errorf("convert image request to async ali image request failed: %w", err)
|
||||||
}
|
}
|
||||||
return aliRequest, nil
|
return aliRequest, nil
|
||||||
}
|
}
|
||||||
@@ -169,13 +205,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
default:
|
default:
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case constant.RelayModeImagesGenerations:
|
case constant.RelayModeImagesGenerations:
|
||||||
err, usage = aliImageHandler(c, resp, info)
|
err, usage = aliImageHandler(a, c, resp, info)
|
||||||
case constant.RelayModeImagesEdits:
|
case constant.RelayModeImagesEdits:
|
||||||
if isWanModel(info.OriginModelName) {
|
err, usage = aliImageHandler(a, c, resp, info)
|
||||||
err, usage = aliImageHandler(c, resp, info)
|
|
||||||
} else {
|
|
||||||
err, usage = aliImageEditHandler(c, resp, info)
|
|
||||||
}
|
|
||||||
case constant.RelayModeRerank:
|
case constant.RelayModeRerank:
|
||||||
err, usage = RerankHandler(c, resp, info)
|
err, usage = RerankHandler(c, resp, info)
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -1,6 +1,13 @@
|
|||||||
package ali
|
package ali
|
||||||
|
|
||||||
import "github.com/QuantumNous/new-api/dto"
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/QuantumNous/new-api/dto"
|
||||||
|
"github.com/QuantumNous/new-api/logger"
|
||||||
|
"github.com/QuantumNous/new-api/service"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
type AliMessage struct {
|
type AliMessage struct {
|
||||||
Content any `json:"content"`
|
Content any `json:"content"`
|
||||||
@@ -65,6 +72,7 @@ type AliUsage struct {
|
|||||||
InputTokens int `json:"input_tokens"`
|
InputTokens int `json:"input_tokens"`
|
||||||
OutputTokens int `json:"output_tokens"`
|
OutputTokens int `json:"output_tokens"`
|
||||||
TotalTokens int `json:"total_tokens"`
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
ImageCount int `json:"image_count,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TaskResult struct {
|
type TaskResult struct {
|
||||||
@@ -82,7 +90,71 @@ type AliOutput struct {
|
|||||||
Message string `json:"message,omitempty"`
|
Message string `json:"message,omitempty"`
|
||||||
Code string `json:"code,omitempty"`
|
Code string `json:"code,omitempty"`
|
||||||
Results []TaskResult `json:"results,omitempty"`
|
Results []TaskResult `json:"results,omitempty"`
|
||||||
Choices []map[string]any `json:"choices,omitempty"`
|
Choices []struct {
|
||||||
|
FinishReason string `json:"finish_reason,omitempty"`
|
||||||
|
Message struct {
|
||||||
|
Role string `json:"role,omitempty"`
|
||||||
|
Content []AliMediaContent `json:"content,omitempty"`
|
||||||
|
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||||
|
} `json:"message,omitempty"`
|
||||||
|
} `json:"choices,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *AliOutput) ChoicesToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData {
|
||||||
|
var imageData []dto.ImageData
|
||||||
|
if len(o.Choices) > 0 {
|
||||||
|
for _, choice := range o.Choices {
|
||||||
|
var data dto.ImageData
|
||||||
|
for _, content := range choice.Message.Content {
|
||||||
|
if content.Image != "" {
|
||||||
|
if strings.HasPrefix(content.Image, "http") {
|
||||||
|
var b64Json string
|
||||||
|
if responseFormat == "b64_json" {
|
||||||
|
_, b64, err := service.GetImageFromUrl(content.Image)
|
||||||
|
if err != nil {
|
||||||
|
logger.LogError(c, "get_image_data_failed: "+err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b64Json = b64
|
||||||
|
}
|
||||||
|
data.Url = content.Image
|
||||||
|
data.B64Json = b64Json
|
||||||
|
} else {
|
||||||
|
data.B64Json = content.Image
|
||||||
|
}
|
||||||
|
} else if content.Text != "" {
|
||||||
|
data.RevisedPrompt = content.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
imageData = append(imageData, data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return imageData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *AliOutput) ResultToOpenAIImageDate(c *gin.Context, responseFormat string) []dto.ImageData {
|
||||||
|
var imageData []dto.ImageData
|
||||||
|
for _, data := range o.Results {
|
||||||
|
var b64Json string
|
||||||
|
if responseFormat == "b64_json" {
|
||||||
|
_, b64, err := service.GetImageFromUrl(data.Url)
|
||||||
|
if err != nil {
|
||||||
|
logger.LogError(c, "get_image_data_failed: "+err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b64Json = b64
|
||||||
|
} else {
|
||||||
|
b64Json = data.B64Image
|
||||||
|
}
|
||||||
|
|
||||||
|
imageData = append(imageData, dto.ImageData{
|
||||||
|
Url: data.Url,
|
||||||
|
B64Json: b64Json,
|
||||||
|
RevisedPrompt: "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return imageData
|
||||||
}
|
}
|
||||||
|
|
||||||
type AliResponse struct {
|
type AliResponse struct {
|
||||||
@@ -94,7 +166,7 @@ type AliResponse struct {
|
|||||||
type AliImageRequest struct {
|
type AliImageRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Input any `json:"input"`
|
Input any `json:"input"`
|
||||||
Parameters any `json:"parameters,omitempty"`
|
Parameters AliImageParameters `json:"parameters,omitempty"`
|
||||||
ResponseFormat string `json:"response_format,omitempty"`
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,6 +176,14 @@ type AliImageParameters struct {
|
|||||||
Steps string `json:"steps,omitempty"`
|
Steps string `json:"steps,omitempty"`
|
||||||
Scale string `json:"scale,omitempty"`
|
Scale string `json:"scale,omitempty"`
|
||||||
Watermark *bool `json:"watermark,omitempty"`
|
Watermark *bool `json:"watermark,omitempty"`
|
||||||
|
PromptExtend *bool `json:"prompt_extend,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *AliImageParameters) PromptExtendValue() bool {
|
||||||
|
if p != nil && p.PromptExtend != nil {
|
||||||
|
return *p.PromptExtend
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
type AliImageInput struct {
|
type AliImageInput struct {
|
||||||
|
|||||||
@@ -21,17 +21,25 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
func oaiImage2AliImageRequest(info *relaycommon.RelayInfo, request dto.ImageRequest, isSync bool) (*AliImageRequest, error) {
|
||||||
var imageRequest AliImageRequest
|
var imageRequest AliImageRequest
|
||||||
imageRequest.Model = request.Model
|
imageRequest.Model = request.Model
|
||||||
imageRequest.ResponseFormat = request.ResponseFormat
|
imageRequest.ResponseFormat = request.ResponseFormat
|
||||||
logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
|
logger.LogJson(context.Background(), "oaiImage2Ali request extra", request.Extra)
|
||||||
|
logger.LogDebug(context.Background(), "oaiImage2Ali request isSync: "+fmt.Sprintf("%v", isSync))
|
||||||
if request.Extra != nil {
|
if request.Extra != nil {
|
||||||
if val, ok := request.Extra["parameters"]; ok {
|
if val, ok := request.Extra["parameters"]; ok {
|
||||||
err := common.Unmarshal(val, &imageRequest.Parameters)
|
err := common.Unmarshal(val, &imageRequest.Parameters)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid parameters field: %w", err)
|
return nil, fmt.Errorf("invalid parameters field: %w", err)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// 兼容没有parameters字段的情况,从openai标准字段中提取参数
|
||||||
|
imageRequest.Parameters = AliImageParameters{
|
||||||
|
Size: strings.Replace(request.Size, "x", "*", -1),
|
||||||
|
N: int(request.N),
|
||||||
|
Watermark: request.Watermark,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if val, ok := request.Extra["input"]; ok {
|
if val, ok := request.Extra["input"]; ok {
|
||||||
err := common.Unmarshal(val, &imageRequest.Input)
|
err := common.Unmarshal(val, &imageRequest.Input)
|
||||||
@@ -41,23 +49,44 @@ func oaiImage2Ali(request dto.ImageRequest) (*AliImageRequest, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if imageRequest.Parameters == nil {
|
if strings.Contains(request.Model, "z-image") {
|
||||||
imageRequest.Parameters = AliImageParameters{
|
// z-image 开启prompt_extend后,按2倍计费
|
||||||
Size: strings.Replace(request.Size, "x", "*", -1),
|
if imageRequest.Parameters.PromptExtendValue() {
|
||||||
N: int(request.N),
|
info.PriceData.AddOtherRatio("prompt_extend", 2)
|
||||||
Watermark: request.Watermark,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查n参数
|
||||||
|
if imageRequest.Parameters.N != 0 {
|
||||||
|
info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 同步图片模型和异步图片模型请求格式不一样
|
||||||
|
if isSync {
|
||||||
|
if imageRequest.Input == nil {
|
||||||
|
imageRequest.Input = AliImageInput{
|
||||||
|
Messages: []AliMessage{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: []AliMediaContent{
|
||||||
|
{
|
||||||
|
Text: request.Prompt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if imageRequest.Input == nil {
|
if imageRequest.Input == nil {
|
||||||
imageRequest.Input = AliImageInput{
|
imageRequest.Input = AliImageInput{
|
||||||
Prompt: request.Prompt,
|
Prompt: request.Prompt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &imageRequest, nil
|
return &imageRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
|
func getImageBase64sFromForm(c *gin.Context, fieldName string) ([]string, error) {
|
||||||
mf := c.Request.MultipartForm
|
mf := c.Request.MultipartForm
|
||||||
if mf == nil {
|
if mf == nil {
|
||||||
@@ -199,6 +228,8 @@ func asyncTaskWait(c *gin.Context, info *relaycommon.RelayInfo, taskID string) (
|
|||||||
var taskResponse AliResponse
|
var taskResponse AliResponse
|
||||||
var responseBody []byte
|
var responseBody []byte
|
||||||
|
|
||||||
|
time.Sleep(time.Duration(5) * time.Second)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
|
logger.LogDebug(c, fmt.Sprintf("asyncTaskWait step %d/%d, wait %d seconds", step, maxStep, waitSeconds))
|
||||||
step++
|
step++
|
||||||
@@ -238,32 +269,17 @@ func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, originBody [
|
|||||||
Created: info.StartTime.Unix(),
|
Created: info.StartTime.Unix(),
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, data := range response.Output.Results {
|
if len(response.Output.Results) > 0 {
|
||||||
var b64Json string
|
imageResponse.Data = response.Output.ResultToOpenAIImageDate(c, responseFormat)
|
||||||
if responseFormat == "b64_json" {
|
} else if len(response.Output.Choices) > 0 {
|
||||||
_, b64, err := service.GetImageFromUrl(data.Url)
|
imageResponse.Data = response.Output.ChoicesToOpenAIImageDate(c, responseFormat)
|
||||||
if err != nil {
|
|
||||||
logger.LogError(c, "get_image_data_failed: "+err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
b64Json = b64
|
|
||||||
} else {
|
|
||||||
b64Json = data.B64Image
|
|
||||||
}
|
}
|
||||||
|
|
||||||
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
|
imageResponse.Metadata = originBody
|
||||||
Url: data.Url,
|
|
||||||
B64Json: b64Json,
|
|
||||||
RevisedPrompt: "",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
var mapResponse map[string]any
|
|
||||||
_ = common.Unmarshal(originBody, &mapResponse)
|
|
||||||
imageResponse.Extra = mapResponse
|
|
||||||
return &imageResponse
|
return &imageResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
func aliImageHandler(a *Adaptor, c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
||||||
responseFormat := c.GetString("response_format")
|
responseFormat := c.GetString("response_format")
|
||||||
|
|
||||||
var aliTaskResponse AliResponse
|
var aliTaskResponse AliResponse
|
||||||
@@ -282,11 +298,20 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
aliResponse, originRespBody, err := asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
|
var (
|
||||||
|
aliResponse *AliResponse
|
||||||
|
originRespBody []byte
|
||||||
|
)
|
||||||
|
|
||||||
|
if a.IsSyncImageModel {
|
||||||
|
aliResponse = &aliTaskResponse
|
||||||
|
originRespBody = responseBody
|
||||||
|
} else {
|
||||||
|
// 异步图片模型需要轮询任务结果
|
||||||
|
aliResponse, originRespBody, err = asyncTaskWait(c, info, aliTaskResponse.Output.TaskId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
return types.NewError(err, types.ErrorCodeBadResponse), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
|
||||||
return types.WithOpenAIError(types.OpenAIError{
|
return types.WithOpenAIError(types.OpenAIError{
|
||||||
Message: aliResponse.Output.Message,
|
Message: aliResponse.Output.Message,
|
||||||
@@ -295,53 +320,27 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
Code: aliResponse.Output.Code,
|
Code: aliResponse.Output.Code,
|
||||||
}, resp.StatusCode), nil
|
}, resp.StatusCode), nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
|
//logger.LogDebug(c, "ali_async_task_result: "+string(originRespBody))
|
||||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
if a.IsSyncImageModel {
|
||||||
if err != nil {
|
logger.LogDebug(c, "ali_sync_image_result: "+string(originRespBody))
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
} else {
|
||||||
}
|
logger.LogDebug(c, "ali_async_image_result: "+string(originRespBody))
|
||||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
}
|
||||||
return nil, &dto.Usage{}
|
|
||||||
}
|
imageResponses := responseAli2OpenAIImage(c, aliResponse, originRespBody, info, responseFormat)
|
||||||
|
// 可能生成多张图片,修正计费数量n
|
||||||
func aliImageEditHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
|
if aliResponse.Usage.ImageCount != 0 {
|
||||||
var aliResponse AliResponse
|
info.PriceData.AddOtherRatio("n", float64(aliResponse.Usage.ImageCount))
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
} else if len(imageResponses.Data) != 0 {
|
||||||
if err != nil {
|
info.PriceData.AddOtherRatio("n", float64(len(imageResponses.Data)))
|
||||||
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
|
}
|
||||||
}
|
jsonResponse, err := common.Marshal(imageResponses)
|
||||||
|
|
||||||
service.CloseResponseBodyGracefully(resp)
|
|
||||||
err = common.Unmarshal(responseBody, &aliResponse)
|
|
||||||
if err != nil {
|
|
||||||
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if aliResponse.Message != "" {
|
|
||||||
logger.LogError(c, "ali_task_failed: "+aliResponse.Message)
|
|
||||||
return types.NewError(errors.New(aliResponse.Message), types.ErrorCodeBadResponse), nil
|
|
||||||
}
|
|
||||||
var fullTextResponse dto.ImageResponse
|
|
||||||
if len(aliResponse.Output.Choices) > 0 {
|
|
||||||
fullTextResponse = dto.ImageResponse{
|
|
||||||
Created: info.StartTime.Unix(),
|
|
||||||
Data: []dto.ImageData{
|
|
||||||
{
|
|
||||||
Url: aliResponse.Output.Choices[0]["message"].(map[string]any)["content"].([]any)[0].(map[string]any)["image"].(string),
|
|
||||||
B64Json: "",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var mapResponse map[string]any
|
|
||||||
_ = common.Unmarshal(responseBody, &mapResponse)
|
|
||||||
fullTextResponse.Extra = mapResponse
|
|
||||||
jsonResponse, err := common.Marshal(fullTextResponse)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
|
||||||
}
|
}
|
||||||
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
service.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||||||
|
|
||||||
return nil, &dto.Usage{}
|
return nil, &dto.Usage{}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,14 +26,22 @@ func oaiFormEdit2WanxImageEdit(c *gin.Context, info *relaycommon.RelayInfo, requ
|
|||||||
if wanInput.Images, err = getImageBase64sFromForm(c, "image"); err != nil {
|
if wanInput.Images, err = getImageBase64sFromForm(c, "image"); err != nil {
|
||||||
return nil, fmt.Errorf("get image base64s from form failed: %w", err)
|
return nil, fmt.Errorf("get image base64s from form failed: %w", err)
|
||||||
}
|
}
|
||||||
wanParams := WanImageParameters{
|
//wanParams := WanImageParameters{
|
||||||
|
// N: int(request.N),
|
||||||
|
//}
|
||||||
|
imageRequest.Input = wanInput
|
||||||
|
imageRequest.Parameters = AliImageParameters{
|
||||||
N: int(request.N),
|
N: int(request.N),
|
||||||
}
|
}
|
||||||
imageRequest.Input = wanInput
|
info.PriceData.AddOtherRatio("n", float64(imageRequest.Parameters.N))
|
||||||
imageRequest.Parameters = wanParams
|
|
||||||
return &imageRequest, nil
|
return &imageRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isOldWanModel(modelName string) bool {
|
||||||
|
return strings.Contains(modelName, "wan") && !strings.Contains(modelName, "wan2.6")
|
||||||
|
}
|
||||||
|
|
||||||
func isWanModel(modelName string) bool {
|
func isWanModel(modelName string) bool {
|
||||||
return strings.Contains(modelName, "wan")
|
return strings.Contains(modelName, "wan")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -184,19 +184,19 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
|
|||||||
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
|
if usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 {
|
||||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||||
} else {
|
} else {
|
||||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
|
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent ...string) {
|
||||||
if usage == nil {
|
if usage == nil {
|
||||||
usage = &dto.Usage{
|
usage = &dto.Usage{
|
||||||
PromptTokens: relayInfo.GetEstimatePromptTokens(),
|
PromptTokens: relayInfo.GetEstimatePromptTokens(),
|
||||||
CompletionTokens: 0,
|
CompletionTokens: 0,
|
||||||
TotalTokens: relayInfo.GetEstimatePromptTokens(),
|
TotalTokens: relayInfo.GetEstimatePromptTokens(),
|
||||||
}
|
}
|
||||||
extraContent += "(可能是请求出错)"
|
extraContent = append(extraContent, "上游无计费信息")
|
||||||
}
|
}
|
||||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||||
promptTokens := usage.PromptTokens
|
promptTokens := usage.PromptTokens
|
||||||
@@ -246,8 +246,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
|||||||
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
||||||
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
|
Mul(decimal.NewFromInt(int64(webSearchTool.CallCount))).
|
||||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||||
extraContent += fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
|
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 %d 次,上下文大小 %s,调用花费 %s",
|
||||||
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String())
|
webSearchTool.CallCount, webSearchTool.SearchContextSize, dWebSearchQuota.String()))
|
||||||
}
|
}
|
||||||
} else if strings.HasSuffix(modelName, "search-preview") {
|
} else if strings.HasSuffix(modelName, "search-preview") {
|
||||||
// search-preview 模型不支持 response api
|
// search-preview 模型不支持 response api
|
||||||
@@ -258,8 +258,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
|||||||
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
|
webSearchPrice = operation_setting.GetWebSearchPricePerThousand(modelName, searchContextSize)
|
||||||
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
dWebSearchQuota = decimal.NewFromFloat(webSearchPrice).
|
||||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||||
extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
|
extraContent = append(extraContent, fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
|
||||||
searchContextSize, dWebSearchQuota.String())
|
searchContextSize, dWebSearchQuota.String()))
|
||||||
}
|
}
|
||||||
// claude web search tool 计费
|
// claude web search tool 计费
|
||||||
var dClaudeWebSearchQuota decimal.Decimal
|
var dClaudeWebSearchQuota decimal.Decimal
|
||||||
@@ -269,8 +269,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
|||||||
claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
|
claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
|
||||||
dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
|
dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
|
||||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
|
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
|
||||||
extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
|
extraContent = append(extraContent, fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
|
||||||
claudeWebSearchCallCount, dClaudeWebSearchQuota.String())
|
claudeWebSearchCallCount, dClaudeWebSearchQuota.String()))
|
||||||
}
|
}
|
||||||
// file search tool 计费
|
// file search tool 计费
|
||||||
var dFileSearchQuota decimal.Decimal
|
var dFileSearchQuota decimal.Decimal
|
||||||
@@ -281,8 +281,8 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
|||||||
dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
|
dFileSearchQuota = decimal.NewFromFloat(fileSearchPrice).
|
||||||
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
|
Mul(decimal.NewFromInt(int64(fileSearchTool.CallCount))).
|
||||||
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||||
extraContent += fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
|
extraContent = append(extraContent, fmt.Sprintf("File Search 调用 %d 次,调用花费 %s",
|
||||||
fileSearchTool.CallCount, dFileSearchQuota.String())
|
fileSearchTool.CallCount, dFileSearchQuota.String()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var dImageGenerationCallQuota decimal.Decimal
|
var dImageGenerationCallQuota decimal.Decimal
|
||||||
@@ -290,7 +290,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
|||||||
if ctx.GetBool("image_generation_call") {
|
if ctx.GetBool("image_generation_call") {
|
||||||
imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
imageGenerationCallPrice = operation_setting.GetGPTImage1PriceOnceCall(ctx.GetString("image_generation_call_quality"), ctx.GetString("image_generation_call_size"))
|
||||||
dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
dImageGenerationCallQuota = decimal.NewFromFloat(imageGenerationCallPrice).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||||
extraContent += fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String())
|
extraContent = append(extraContent, fmt.Sprintf("Image Generation Call 花费 %s", dImageGenerationCallQuota.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
var quotaCalculateDecimal decimal.Decimal
|
var quotaCalculateDecimal decimal.Decimal
|
||||||
@@ -331,7 +331,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
|||||||
// 重新计算 base tokens
|
// 重新计算 base tokens
|
||||||
baseTokens = baseTokens.Sub(dAudioTokens)
|
baseTokens = baseTokens.Sub(dAudioTokens)
|
||||||
audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
audioInputQuota = decimal.NewFromFloat(audioInputPrice).Div(decimal.NewFromInt(1000000)).Mul(dAudioTokens).Mul(dGroupRatio).Mul(dQuotaPerUnit)
|
||||||
extraContent += fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String())
|
extraContent = append(extraContent, fmt.Sprintf("Audio Input 花费 %s", audioInputQuota.String()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
promptQuota := baseTokens.Add(cachedTokensWithRatio).
|
promptQuota := baseTokens.Add(cachedTokensWithRatio).
|
||||||
@@ -356,17 +356,25 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
|||||||
// 添加 image generation call 计费
|
// 添加 image generation call 计费
|
||||||
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
quotaCalculateDecimal = quotaCalculateDecimal.Add(dImageGenerationCallQuota)
|
||||||
|
|
||||||
|
if len(relayInfo.PriceData.OtherRatios) > 0 {
|
||||||
|
for key, otherRatio := range relayInfo.PriceData.OtherRatios {
|
||||||
|
dOtherRatio := decimal.NewFromFloat(otherRatio)
|
||||||
|
quotaCalculateDecimal = quotaCalculateDecimal.Mul(dOtherRatio)
|
||||||
|
extraContent = append(extraContent, fmt.Sprintf("其他倍率 %s: %f", key, otherRatio))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
quota := int(quotaCalculateDecimal.Round(0).IntPart())
|
quota := int(quotaCalculateDecimal.Round(0).IntPart())
|
||||||
totalTokens := promptTokens + completionTokens
|
totalTokens := promptTokens + completionTokens
|
||||||
|
|
||||||
var logContent string
|
//var logContent string
|
||||||
|
|
||||||
// record all the consume log even if quota is 0
|
// record all the consume log even if quota is 0
|
||||||
if totalTokens == 0 {
|
if totalTokens == 0 {
|
||||||
// in this case, must be some error happened
|
// in this case, must be some error happened
|
||||||
// we cannot just return, because we may have to return the pre-consumed quota
|
// we cannot just return, because we may have to return the pre-consumed quota
|
||||||
quota = 0
|
quota = 0
|
||||||
logContent += fmt.Sprintf("(可能是上游超时)")
|
extraContent = append(extraContent, "上游没有返回计费信息,无法扣费(可能是上游超时)")
|
||||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
||||||
} else {
|
} else {
|
||||||
@@ -405,15 +413,13 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage
|
|||||||
logModel := modelName
|
logModel := modelName
|
||||||
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
if strings.HasPrefix(logModel, "gpt-4-gizmo") {
|
||||||
logModel = "gpt-4-gizmo-*"
|
logModel = "gpt-4-gizmo-*"
|
||||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
if strings.HasPrefix(logModel, "gpt-4o-gizmo") {
|
||||||
logModel = "gpt-4o-gizmo-*"
|
logModel = "gpt-4o-gizmo-*"
|
||||||
logContent += fmt.Sprintf(",模型 %s", modelName)
|
extraContent = append(extraContent, fmt.Sprintf("模型 %s", modelName))
|
||||||
}
|
|
||||||
if extraContent != "" {
|
|
||||||
logContent += ", " + extraContent
|
|
||||||
}
|
}
|
||||||
|
logContent := strings.Join(extraContent, ", ")
|
||||||
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, cacheTokens, cacheRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||||
if imageTokens != 0 {
|
if imageTokens != 0 {
|
||||||
other["image"] = true
|
other["image"] = true
|
||||||
|
|||||||
@@ -82,6 +82,6 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
|||||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return newAPIError
|
return newAPIError
|
||||||
}
|
}
|
||||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -193,7 +193,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
|||||||
return openaiErr
|
return openaiErr
|
||||||
}
|
}
|
||||||
|
|
||||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -292,6 +292,6 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo) (newAPI
|
|||||||
return openaiErr
|
return openaiErr
|
||||||
}
|
}
|
||||||
|
|
||||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -124,12 +124,18 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
|
|||||||
quality = "hd"
|
quality = "hd"
|
||||||
}
|
}
|
||||||
|
|
||||||
var logContent string
|
var logContent []string
|
||||||
|
|
||||||
if len(request.Size) > 0 {
|
if len(request.Size) > 0 {
|
||||||
logContent = fmt.Sprintf("大小 %s, 品质 %s, 张数 %d", request.Size, quality, request.N)
|
logContent = append(logContent, fmt.Sprintf("大小 %s", request.Size))
|
||||||
|
}
|
||||||
|
if len(quality) > 0 {
|
||||||
|
logContent = append(logContent, fmt.Sprintf("品质 %s", quality))
|
||||||
|
}
|
||||||
|
if request.N > 0 {
|
||||||
|
logContent = append(logContent, fmt.Sprintf("生成数量 %d", request.N))
|
||||||
}
|
}
|
||||||
|
|
||||||
postConsumeQuota(c, info, usage.(*dto.Usage), logContent)
|
postConsumeQuota(c, info, usage.(*dto.Usage), logContent...)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -95,6 +95,6 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
|
|||||||
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
|
||||||
return newAPIError
|
return newAPIError
|
||||||
}
|
}
|
||||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
|
|||||||
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
|
if strings.HasPrefix(info.OriginModelName, "gpt-4o-audio") {
|
||||||
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
|
||||||
} else {
|
} else {
|
||||||
postConsumeQuota(c, info, usage.(*dto.Usage), "")
|
postConsumeQuota(c, info, usage.(*dto.Usage))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,12 +26,22 @@ type PriceData struct {
|
|||||||
GroupRatioInfo GroupRatioInfo
|
GroupRatioInfo GroupRatioInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *PriceData) AddOtherRatio(key string, ratio float64) {
|
||||||
|
if p.OtherRatios == nil {
|
||||||
|
p.OtherRatios = make(map[string]float64)
|
||||||
|
}
|
||||||
|
if ratio <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.OtherRatios[key] = ratio
|
||||||
|
}
|
||||||
|
|
||||||
type PerCallPriceData struct {
|
type PerCallPriceData struct {
|
||||||
ModelPrice float64
|
ModelPrice float64
|
||||||
Quota int
|
Quota int
|
||||||
GroupRatioInfo GroupRatioInfo
|
GroupRatioInfo GroupRatioInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p PriceData) ToSetting() string {
|
func (p *PriceData) ToSetting() string {
|
||||||
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
|
return fmt.Sprintf("ModelPrice: %f, ModelRatio: %f, CompletionRatio: %f, CacheRatio: %f, GroupRatio: %f, UsePrice: %t, CacheCreationRatio: %f, CacheCreation5mRatio: %f, CacheCreation1hRatio: %f, QuotaToPreConsume: %d, ImageRatio: %f, AudioRatio: %f, AudioCompletionRatio: %f", p.ModelPrice, p.ModelRatio, p.CompletionRatio, p.CacheRatio, p.GroupRatioInfo.GroupRatio, p.UsePrice, p.CacheCreationRatio, p.CacheCreation5mRatio, p.CacheCreation1hRatio, p.QuotaToPreConsume, p.ImageRatio, p.AudioRatio, p.AudioCompletionRatio)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user