Files
new-api/dto/openai_image.go
2025-08-25 14:33:12 +08:00

148 lines
3.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package dto
import (
"encoding/json"
"one-api/common"
"one-api/types"
"reflect"
"strings"
"github.com/gin-gonic/gin"
)
// ImageRequest describes the JSON body of an image request.
// Fields typed json.RawMessage keep their value as undecoded JSON.
// The UnmarshalJSON method on this type stores every top-level key that
// matches none of the json tags below in Extra.
type ImageRequest struct {
Model string `json:"model"`
// Prompt carries a `binding:"required"` tag in addition to its json tag.
Prompt string `json:"prompt" binding:"required"`
// N is used as a multiplier for ImagePriceRatio in GetTokenCountMeta.
N uint `json:"n,omitempty"`
// Size and Quality are compared against fixed strings in GetTokenCountMeta
// when Model starts with "dall-e".
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style json.RawMessage `json:"style,omitempty"`
User json.RawMessage `json:"user,omitempty"`
ExtraFields json.RawMessage `json:"extra_fields,omitempty"`
Background json.RawMessage `json:"background,omitempty"`
Moderation json.RawMessage `json:"moderation,omitempty"`
OutputFormat json.RawMessage `json:"output_format,omitempty"`
OutputCompression json.RawMessage `json:"output_compression,omitempty"`
PartialImages json.RawMessage `json:"partial_images,omitempty"`
// Stream bool `json:"stream,omitempty"`
// Watermark is a pointer, so an absent key (nil) is distinguishable from an explicit false.
Watermark *bool `json:"watermark,omitempty"`
// Extra receives request parameters that match none of the tagged fields above.
// Its "-" tag keeps it out of regular JSON encoding and decoding.
Extra map[string]json.RawMessage `json:"-"`
}
// UnmarshalJSON decodes data into the request. Top-level keys that correspond
// to a json-tagged field populate that field; every other top-level key is
// kept, still as raw JSON, in i.Extra. On any decoding error the receiver is
// left unmodified and the error is returned as-is.
func (i *ImageRequest) UnmarshalJSON(data []byte) error {
	// First pass: capture every top-level key together with its raw value.
	var raw map[string]json.RawMessage
	if err := common.Unmarshal(data, &raw); err != nil {
		return err
	}

	// Key names claimed by the struct's json tags.
	declared := GetJSONFieldNames(reflect.TypeOf(*i))

	// Second pass: decode the declared fields through a local type that has
	// the same layout but no methods, so decoding it cannot recurse into this
	// UnmarshalJSON.
	type plain ImageRequest
	var decoded plain
	if err := common.Unmarshal(data, &decoded); err != nil {
		return err
	}
	*i = ImageRequest(decoded)

	// Everything the struct did not claim is preserved in Extra.
	extras := make(map[string]json.RawMessage)
	for key, value := range raw {
		if _, taken := declared[key]; taken {
			continue
		}
		extras[key] = value
	}
	i.Extra = extras

	return nil
}
// GetJSONFieldNames returns the set of JSON object keys bound to the direct
// fields of struct type t, following the encoding/json naming rules:
//
//   - a field tagged `json:"-"` is excluded;
//   - the key is the tag text before the first comma;
//   - when that text is empty (no tag, or a tag such as `json:",omitempty"`)
//     the Go field name is used;
//   - unexported fields are excluded.
//
// Embedded (anonymous) fields are skipped and are not flattened. Pointer types
// are dereferenced first. A nil or non-struct type yields an empty set instead
// of panicking.
func GetJSONFieldNames(t reflect.Type) map[string]struct{} {
	names := make(map[string]struct{})
	if t == nil {
		return names
	}
	// Accept *T (and **T, ...) as a convenience for callers holding a pointer.
	for t.Kind() == reflect.Pointer {
		t = t.Elem()
	}
	// NumField panics on anything but a struct.
	if t.Kind() != reflect.Struct {
		return names
	}

	for idx := 0; idx < t.NumField(); idx++ {
		field := t.Field(idx)
		// Skip embedded fields and fields encoding/json cannot see.
		if field.Anonymous || !field.IsExported() {
			continue
		}
		tag := field.Tag.Get("json")
		if tag == "-" {
			continue
		}
		// Drop options such as ",omitempty"; only the name part matters here.
		name, _, _ := strings.Cut(tag, ",")
		if name == "" {
			// encoding/json falls back to the Go field name in this case.
			name = field.Name
		}
		names[name] = struct{}{}
	}
	return names
}
// indexComma returns the byte index of the first ',' in s, or -1 when s
// contains no comma.
func indexComma(s string) int {
	return strings.IndexByte(s, ',')
}
// GetTokenCountMeta builds the token-count metadata for this request.
//
// CombineText is the prompt and MaxTokens is fixed at 1584. ImagePriceRatio is
// sizeRatio * qualityRatio * N, where both ratios are 1 unless Model starts
// with "dall-e":
//   - sizeRatio depends on Size (0.4, 0.45, 1 or 2; unrecognized sizes keep 1);
//   - qualityRatio applies only to "dall-e-3" with Quality "hd": 1.5 for the
//     1024x1792 / 1792x1024 sizes, otherwise 2.
//
// N is used as a plain factor, so a request whose N is 0 yields a ratio of 0.
func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
	sizeRatio, qualityRatio := 1.0, 1.0

	if strings.HasPrefix(i.Model, "dall-e") {
		// The two non-square sizes share both their size and HD handling.
		nonSquare := i.Size == "1024x1792" || i.Size == "1792x1024"

		switch {
		case i.Size == "256x256":
			sizeRatio = 0.4
		case i.Size == "512x512":
			sizeRatio = 0.45
		case i.Size == "1024x1024":
			sizeRatio = 1
		case nonSquare:
			sizeRatio = 2
		}

		if i.Model == "dall-e-3" && i.Quality == "hd" {
			if nonSquare {
				qualityRatio = 1.5
			} else {
				qualityRatio = 2.0
			}
		}
	}

	// Token counting is not supported for dalle; only the price ratio varies.
	meta := &types.TokenCountMeta{
		CombineText:     i.Prompt,
		MaxTokens:       1584,
		ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N),
	}
	return meta
}
// IsStream reports whether the request is streamed. It always returns false;
// the gin context argument is not consulted.
func (i *ImageRequest) IsStream(c *gin.Context) bool {
return false
}
// SetModelName replaces the request's Model with modelName. An empty
// modelName is ignored and leaves the current Model untouched.
func (i *ImageRequest) SetModelName(modelName string) {
	if modelName == "" {
		return
	}
	i.Model = modelName
}
// ImageResponse is the JSON envelope of an image response.
type ImageResponse struct {
// Data is the list of ImageData entries, serialized under "data".
Data []ImageData `json:"data"`
// Created is serialized under "created".
// NOTE(review): presumably a Unix timestamp — confirm against the producer.
Created int64 `json:"created"`
// Extra is an optional free-form value; it is omitted from the output when nil.
Extra any `json:"extra,omitempty"`
}
// ImageData is one element of ImageResponse.Data.
// None of its fields use omitempty, so all three keys are written even when
// their value is the empty string.
type ImageData struct {
// Url is serialized under "url".
// NOTE(review): presumably the location of the image — confirm.
Url string `json:"url"`
// B64Json is serialized under "b64_json".
// NOTE(review): presumably base64-encoded image content — confirm.
B64Json string `json:"b64_json"`
// RevisedPrompt is serialized under "revised_prompt".
RevisedPrompt string `json:"revised_prompt"`
}