diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index f2fb1b776..daffff180 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -6,6 +6,7 @@ import ( "io" "net/http" + "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" "github.com/QuantumNous/new-api/relay/channel/openai" @@ -35,8 +36,27 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - adaptor := openai.Adaptor{} - return adaptor.ConvertImageRequest(c, info, request) + // 解析extra到SFImageRequest里,以填入SiliconFlow特殊字段。若失败重建一个空的。 + sfRequest := &SFImageRequest{} + extra, err := common.Marshal(request.Extra) + if err == nil { + err = common.Unmarshal(extra, sfRequest) + if err != nil { + sfRequest = &SFImageRequest{} + } + } + + sfRequest.Model = request.Model + sfRequest.Prompt = request.Prompt + // 优先使用image_size/batch_size,否则使用OpenAI标准的size/n + if sfRequest.ImageSize == "" { + sfRequest.ImageSize = request.Size + } + if sfRequest.BatchSize == 0 { + sfRequest.BatchSize = request.N + } + + return sfRequest, nil } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { @@ -51,6 +71,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } else if info.RelayMode == constant.RelayModeCompletions { return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil + } else if info.RelayMode == constant.RelayModeImagesGenerations { + return fmt.Sprintf("%s/v1/images/generations", info.ChannelBaseUrl), nil } return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil } @@ -102,6 +124,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom fallthrough case constant.RelayModeChatCompletions: fallthrough + case constant.RelayModeImagesGenerations: + fallthrough default: if info.IsStream { usage, err = openai.OaiStreamHandler(c, info, resp) diff --git a/relay/channel/siliconflow/dto.go b/relay/channel/siliconflow/dto.go index f075542c0..100975107 100644 --- a/relay/channel/siliconflow/dto.go +++ b/relay/channel/siliconflow/dto.go @@ -15,3 +15,18 @@ type SFRerankResponse struct { Results []dto.RerankResponseResult `json:"results"` Meta SFMeta `json:"meta"` } + +type SFImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt,omitempty"` + ImageSize string `json:"image_size,omitempty"` + BatchSize uint `json:"batch_size,omitempty"` + Seed uint64 `json:"seed,omitempty"` + NumInferenceSteps uint `json:"num_inference_steps,omitempty"` + GuidanceScale float64 `json:"guidance_scale,omitempty"` + Cfg float64 `json:"cfg,omitempty"` + Image string `json:"image,omitempty"` + Image2 string `json:"image2,omitempty"` + Image3 string `json:"image3,omitempty"` +}